mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Compare commits
12 Commits
v0.23.7
...
yury/use-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
026958c22a | ||
|
|
9bfab103a0 | ||
|
|
f7e6cdcbf0 | ||
|
|
af6fdd3af2 | ||
|
|
5781ec7a8e | ||
|
|
1219006a6e | ||
|
|
4791e41004 | ||
|
|
9131069d12 | ||
|
|
dece311076 | ||
|
|
206d903de5 | ||
|
|
3e20f23646 | ||
|
|
025fefc6bd |
1
.gitattributes
vendored
Normal file
1
.gitattributes
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
*.go text eol=lf
|
||||||
20
.github/workflows/golangci-lint.yml
vendored
20
.github/workflows/golangci-lint.yml
vendored
@@ -1,12 +1,23 @@
|
|||||||
name: golangci-lint
|
name: golangci-lint
|
||||||
on: [pull_request]
|
on: [pull_request]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: read
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
golangci:
|
golangci:
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [macos-latest, windows-latest, ubuntu-latest]
|
||||||
name: lint
|
name: lint
|
||||||
runs-on: ubuntu-latest
|
runs-on: ${{ matrix.os }}
|
||||||
|
timeout-minutes: 15
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
@@ -14,7 +25,12 @@ jobs:
|
|||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: "1.20.x"
|
go-version: "1.20.x"
|
||||||
|
cache: false
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
|
if: matrix.os == 'ubuntu-latest'
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v3
|
uses: golangci/golangci-lint-action@v3
|
||||||
|
with:
|
||||||
|
version: latest
|
||||||
|
args: --timeout=12m
|
||||||
21
.github/workflows/test-infrastructure-files.yml
vendored
21
.github/workflows/test-infrastructure-files.yml
vendored
@@ -112,6 +112,27 @@ jobs:
|
|||||||
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||||
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: Build management binary
|
||||||
|
working-directory: management
|
||||||
|
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
|
||||||
|
|
||||||
|
- name: Build management docker image
|
||||||
|
working-directory: management
|
||||||
|
run: |
|
||||||
|
docker build -t netbirdio/management:latest .
|
||||||
|
|
||||||
|
- name: Build signal binary
|
||||||
|
working-directory: signal
|
||||||
|
run: CGO_ENABLED=0 go build -o netbird-signal main.go
|
||||||
|
|
||||||
|
- name: Build signal docker image
|
||||||
|
working-directory: signal
|
||||||
|
run: |
|
||||||
|
docker build -t netbirdio/signal:latest .
|
||||||
|
|
||||||
- name: run docker compose up
|
- name: run docker compose up
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
5
client/ui/build-ui-linux.sh
Normal file
5
client/ui/build-ui-linux.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
sudo apt update
|
||||||
|
sudo apt remove gir1.2-appindicator3-0.1
|
||||||
|
sudo apt install -y libayatana-appindicator3-dev
|
||||||
|
go build
|
||||||
@@ -202,9 +202,10 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.Login(s.ctx, &proto.LoginRequest{
|
_, err = client.Login(s.ctx, &proto.LoginRequest{
|
||||||
ManagementUrl: s.iMngURL.Text,
|
ManagementUrl: s.iMngURL.Text,
|
||||||
AdminURL: s.iAdminURL.Text,
|
AdminURL: s.iAdminURL.Text,
|
||||||
PreSharedKey: s.iPreSharedKey.Text,
|
PreSharedKey: s.iPreSharedKey.Text,
|
||||||
|
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("login to management URL: %v", err)
|
log.Errorf("login to management URL: %v", err)
|
||||||
@@ -233,7 +234,9 @@ func (s *serviceClient) login() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{})
|
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
|
||||||
|
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("login to management URL with: %v", err)
|
log.Errorf("login to management URL with: %v", err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -62,12 +62,9 @@ type AccountManager interface {
|
|||||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
MarkPATUsed(tokenID string) error
|
MarkPATUsed(tokenID string) error
|
||||||
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||||
AccountExists(accountId string) (*bool, error)
|
|
||||||
GetPeerByKey(peerKey string) (*Peer, error)
|
|
||||||
GetPeers(accountID, userID string) ([]*Peer, error)
|
GetPeers(accountID, userID string) ([]*Peer, error)
|
||||||
MarkPeerConnected(peerKey string, connected bool) error
|
MarkPeerConnected(peerKey string, connected bool) error
|
||||||
DeletePeer(accountID, peerID, userID string) error
|
DeletePeer(accountID, peerID, userID string) error
|
||||||
GetPeerByIP(accountId string, peerIP string) (*Peer, error)
|
|
||||||
UpdatePeer(accountID, userID string, peer *Peer) (*Peer, error)
|
UpdatePeer(accountID, userID string, peer *Peer) (*Peer, error)
|
||||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||||
GetPeerNetwork(peerID string) (*Network, error)
|
GetPeerNetwork(peerID string) (*Network, error)
|
||||||
@@ -84,7 +81,6 @@ type AccountManager interface {
|
|||||||
ListGroups(accountId string) ([]*Group, error)
|
ListGroups(accountId string) ([]*Group, error)
|
||||||
GroupAddPeer(accountId, groupID, peerID string) error
|
GroupAddPeer(accountId, groupID, peerID string) error
|
||||||
GroupDeletePeer(accountId, groupID, peerID string) error
|
GroupDeletePeer(accountId, groupID, peerID string) error
|
||||||
GroupListPeers(accountId, groupID string) ([]*Peer, error)
|
|
||||||
GetPolicy(accountID, policyID, userID string) (*Policy, error)
|
GetPolicy(accountID, policyID, userID string) (*Policy, error)
|
||||||
SavePolicy(accountID, userID string, policy *Policy) error
|
SavePolicy(accountID, userID string, policy *Policy) error
|
||||||
DeletePolicy(accountID, policyID, userID string) error
|
DeletePolicy(accountID, policyID, userID string) error
|
||||||
@@ -303,17 +299,6 @@ func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
|
|||||||
return routes
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerByIP returns peer by it's IP if exists under account or nil otherwise
|
|
||||||
func (a *Account) GetPeerByIP(peerIP string) *Peer {
|
|
||||||
for _, peer := range a.Peers {
|
|
||||||
if peerIP == peer.IP.String() {
|
|
||||||
return peer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGroup returns a group by ID if exists, nil otherwise
|
// GetGroup returns a group by ID if exists, nil otherwise
|
||||||
func (a *Account) GetGroup(groupID string) *Group {
|
func (a *Account) GetGroup(groupID string) *Group {
|
||||||
return a.Groups[groupID]
|
return a.Groups[groupID]
|
||||||
@@ -1491,9 +1476,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
|||||||
if err := am.Store.SaveAccount(account); err != nil {
|
if err := am.Store.SaveAccount(account); err != nil {
|
||||||
log.Errorf("failed to save account: %v", err)
|
log.Errorf("failed to save account: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if err := am.updateAccountPeers(account); err != nil {
|
am.updateAccountPeers(account)
|
||||||
log.Errorf("failed updating account peers while updating user %s", account.Id)
|
|
||||||
}
|
|
||||||
for _, g := range addNewGroups {
|
for _, g := range addNewGroups {
|
||||||
if group := account.GetGroup(g); group != nil {
|
if group := account.GetGroup(g); group != nil {
|
||||||
am.storeEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
am.storeEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||||
@@ -1604,26 +1587,6 @@ func isDomainValid(domain string) bool {
|
|||||||
return re.Match([]byte(domain))
|
return re.Match([]byte(domain))
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountExists checks whether account exists (returns true) or not (returns false)
|
|
||||||
func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
var res bool
|
|
||||||
_, err := am.Store.GetAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
|
||||||
res = false
|
|
||||||
return &res, nil
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res = true
|
|
||||||
return &res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDNSDomain returns the configured dnsDomain
|
// GetDNSDomain returns the configured dnsDomain
|
||||||
func (am *DefaultAccountManager) GetDNSDomain() string {
|
func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||||
return am.dnsDomain
|
return am.dnsDomain
|
||||||
|
|||||||
@@ -706,30 +706,6 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string)
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_AccountExists(t *testing.T) {
|
|
||||||
manager, err := createManager(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedId := "test_account"
|
|
||||||
userId := "account_creator"
|
|
||||||
_, err = createAccount(manager, expectedId, userId, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
exists, err := manager.AccountExists(expectedId)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !*exists {
|
|
||||||
t.Errorf("expected account to exist after creation, got false")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountManager_GetAccount(t *testing.T) {
|
func TestAccountManager_GetAccount(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -122,7 +122,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
|
|||||||
am.storeEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
am.storeEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
|
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
|
||||||
|
|||||||
@@ -84,10 +84,7 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// the following snippet tracks the activity and stores the group events in the event store.
|
// the following snippet tracks the activity and stores the group events in the event store.
|
||||||
// It has to happen after all the operations have been successfully performed.
|
// It has to happen after all the operations have been successfully performed.
|
||||||
@@ -229,7 +226,9 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
|||||||
|
|
||||||
am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListGroups objects of the peers
|
// ListGroups objects of the peers
|
||||||
@@ -281,7 +280,9 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GroupDeletePeer removes peer from the group
|
// GroupDeletePeer removes peer from the group
|
||||||
@@ -309,31 +310,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
}
|
|
||||||
|
return nil
|
||||||
// GroupListPeers returns list of the peers from the group
|
|
||||||
func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) {
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "account not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
|
||||||
}
|
|
||||||
|
|
||||||
peers := make([]*Peer, 0, len(account.Groups))
|
|
||||||
for _, peerID := range group.Peers {
|
|
||||||
p, ok := account.Peers[peerID]
|
|
||||||
if ok {
|
|
||||||
peers = append(peers, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return peers, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
|||||||
if appMetrics != nil {
|
if appMetrics != nil {
|
||||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||||
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||||
return int64(len(peersUpdateManager.peerChannels))
|
return peersUpdateManager.Len()
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -53,14 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
|||||||
Issued: server.GroupIssuedAPI,
|
Issued: server.GroupIssuedAPI,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
|
||||||
for _, peer := range TestPeers {
|
|
||||||
if peer.IP.String() == peerIP {
|
|
||||||
return peer, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("peer not found")
|
|
||||||
},
|
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
|
|||||||
@@ -125,15 +125,6 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
|
||||||
if peerIP != existingPeerID {
|
|
||||||
return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
|
|
||||||
}
|
|
||||||
return &server.Peer{
|
|
||||||
Key: existingPeerKey,
|
|
||||||
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return testingAccount, testingAccount.Users["test_user"], nil
|
return testingAccount, testingAccount.Users["test_user"], nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -20,12 +20,9 @@ type MockAccountManager struct {
|
|||||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
AccountExistsFunc func(accountId string) (*bool, error)
|
|
||||||
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
|
||||||
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
||||||
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
||||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||||
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
|
|
||||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||||
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error)
|
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error)
|
||||||
@@ -35,7 +32,6 @@ type MockAccountManager struct {
|
|||||||
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
||||||
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
||||||
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
||||||
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
|
|
||||||
GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
|
GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
|
||||||
SaveRuleFunc func(accountID, userID string, rule *server.Rule) error
|
SaveRuleFunc func(accountID, userID string, rule *server.Rule) error
|
||||||
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
||||||
@@ -140,22 +136,6 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountExists mock implementation of AccountExists from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
|
|
||||||
if am.AccountExistsFunc != nil {
|
|
||||||
return am.AccountExistsFunc(accountId)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerByKey mocks implementation of GetPeerByKey from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) GetPeerByKey(peerKey string) (*server.Peer, error) {
|
|
||||||
if am.GetPeerByKeyFunc != nil {
|
|
||||||
return am.GetPeerByKeyFunc(peerKey)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerByKey is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||||
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
|
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
|
||||||
if am.MarkPeerConnectedFunc != nil {
|
if am.MarkPeerConnectedFunc != nil {
|
||||||
@@ -164,14 +144,6 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool)
|
|||||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) {
|
|
||||||
if am.GetPeerByIPFunc != nil {
|
|
||||||
return am.GetPeerByIPFunc(accountId, peerIP)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||||
if am.GetAccountFromPATFunc != nil {
|
if am.GetAccountFromPATFunc != nil {
|
||||||
@@ -296,14 +268,6 @@ func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string)
|
|||||||
return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented")
|
return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GroupListPeers mock implementation of GroupListPeers from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*server.Peer, error) {
|
|
||||||
if am.GroupListPeersFunc != nil {
|
|
||||||
return am.GroupListPeersFunc(accountID, groupID)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRule mock implementation of GetRule from server.AccountManager interface
|
// GetRule mock implementation of GetRule from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
|
func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
|
||||||
if am.GetRuleFunc != nil {
|
if am.GetRuleFunc != nil {
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
@@ -74,11 +73,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
am.storeEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
am.storeEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||||
|
|
||||||
@@ -113,11 +108,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
am.storeEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
am.storeEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||||
|
|
||||||
@@ -147,10 +138,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID)
|
|
||||||
}
|
|
||||||
|
|
||||||
am.storeEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
am.storeEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||||
|
|
||||||
|
|||||||
@@ -195,16 +195,6 @@ func (p *PeerStatus) Copy() *PeerStatus {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerByKey looks up peer by its public WireGuard key
|
|
||||||
func (am *DefaultAccountManager) GetPeerByKey(peerPubKey string) (*Peer, error) {
|
|
||||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.FindPeerByPubKey(peerPubKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||||
// the current user is not an admin.
|
// the current user is not an admin.
|
||||||
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
|
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
|
||||||
@@ -290,10 +280,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
|||||||
if oldStatus.LoginExpired {
|
if oldStatus.LoginExpired {
|
||||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||||
// the expired one. Here we notify them that connection is now allowed again.
|
// the expired one. Here we notify them that connection is now allowed again.
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -364,10 +351,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Pe
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
@@ -433,26 +417,9 @@ func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) er
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerByIP returns peer by its IP
|
return nil
|
||||||
func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) {
|
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, peer := range account.Peers {
|
|
||||||
if peerIP == peer.IP.String() {
|
|
||||||
return peer, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||||
@@ -622,10 +589,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
|||||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||||
am.storeEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
am.storeEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain)
|
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain)
|
||||||
return newPeer, networkMap, nil
|
return newPeer, networkMap, nil
|
||||||
@@ -740,10 +704,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if updateRemotePeers {
|
if updateRemotePeers {
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
|
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
|
||||||
}
|
}
|
||||||
@@ -817,10 +778,7 @@ func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *Peer, account *A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// trigger network map update
|
// trigger network map update
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
@@ -865,7 +823,9 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// trigger network map update
|
// trigger network map update
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||||
@@ -922,18 +882,12 @@ func updatePeerMeta(peer *Peer, meta PeerSystemMeta, account *Account) (*Peer, b
|
|||||||
|
|
||||||
// updateAccountPeers updates all peers that belong to an account.
|
// updateAccountPeers updates all peers that belong to an account.
|
||||||
// Should be called when changes have to be synced to peers.
|
// Should be called when changes have to be synced to peers.
|
||||||
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
|
func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
|
||||||
peers := account.GetPeers()
|
peers := account.GetPeers()
|
||||||
|
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
remotePeerNetworkMap, err := am.GetNetworkMap(peer.ID)
|
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
||||||
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
|
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -350,7 +350,9 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
|
|||||||
}
|
}
|
||||||
am.storeEvent(userID, policy.ID, accountID, action, policy.EventMeta())
|
am.storeEvent(userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePolicy from the store
|
// DeletePolicy from the store
|
||||||
@@ -375,7 +377,9 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
|
|||||||
|
|
||||||
am.storeEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
am.storeEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPolicies from the store
|
// ListPolicies from the store
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetRoute gets a route object from account and route IDs
|
// GetRoute gets a route object from account and route IDs
|
||||||
@@ -185,11 +184,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return &newRoute, status.Errorf(status.Internal, "failed to update peers after create route %s", newPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
am.storeEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta())
|
am.storeEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||||
|
|
||||||
@@ -250,10 +245,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.storeEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
am.storeEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||||
|
|
||||||
@@ -283,7 +275,9 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
|
|||||||
|
|
||||||
am.storeEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta())
|
am.storeEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta())
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListRoutes returns a list of routes from account
|
// ListRoutes returns a list of routes from account
|
||||||
|
|||||||
@@ -317,7 +317,9 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return newKey, am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return newKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListSetupKeys returns a list of all setup keys of the account
|
// ListSetupKeys returns a list of all setup keys of the account
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -14,25 +15,29 @@ type UpdateMessage struct {
|
|||||||
Update *proto.SyncResponse
|
Update *proto.SyncResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UpdateChannel chan *UpdateMessage
|
||||||
|
|
||||||
type PeersUpdateManager struct {
|
type PeersUpdateManager struct {
|
||||||
// peerChannels is an update channel indexed by Peer.ID
|
// peerChannels is an update channel indexed by Peer.ID
|
||||||
peerChannels map[string]chan *UpdateMessage
|
peerChannels sync.Map
|
||||||
channelsMux *sync.Mutex
|
// peerChannelLocks keeps the peer locks to organize channel creations
|
||||||
|
peerChannelLocks sync.Map
|
||||||
|
// len is the length of peerChannels
|
||||||
|
len atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||||
func NewPeersUpdateManager() *PeersUpdateManager {
|
func NewPeersUpdateManager() *PeersUpdateManager {
|
||||||
return &PeersUpdateManager{
|
return &PeersUpdateManager{}
|
||||||
peerChannels: make(map[string]chan *UpdateMessage),
|
|
||||||
channelsMux: &sync.Mutex{},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendUpdate sends update message to the peer's channel
|
// SendUpdate sends update message to the peer's channel
|
||||||
func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) {
|
func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) {
|
||||||
p.channelsMux.Lock()
|
if ch, ok := p.peerChannels.Load(peerID); ok {
|
||||||
defer p.channelsMux.Unlock()
|
channel, ok := ch.(UpdateChannel)
|
||||||
if channel, ok := p.peerChannels[peerID]; ok {
|
if !ok {
|
||||||
|
log.Warnf("could not cast to UpdateChannel")
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case channel <- update:
|
case channel <- update:
|
||||||
log.Debugf("update was sent to channel for peer %s", peerID)
|
log.Debugf("update was sent to channel for peer %s", peerID)
|
||||||
@@ -45,35 +50,48 @@ func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||||
func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage {
|
func (p *PeersUpdateManager) CreateChannel(peerID string) UpdateChannel {
|
||||||
p.channelsMux.Lock()
|
// we have to lock the whole operation by peerID as we do two non atomic operations:
|
||||||
defer p.channelsMux.Unlock()
|
// - closeChannel()
|
||||||
|
// - Store
|
||||||
|
value, _ := p.peerChannelLocks.LoadOrStore(peerID, &sync.Mutex{})
|
||||||
|
mtx := value.(*sync.Mutex)
|
||||||
|
mtx.Lock()
|
||||||
|
defer mtx.Unlock()
|
||||||
|
|
||||||
|
p.closeChannel(peerID)
|
||||||
|
|
||||||
if channel, ok := p.peerChannels[peerID]; ok {
|
|
||||||
delete(p.peerChannels, peerID)
|
|
||||||
close(channel)
|
|
||||||
}
|
|
||||||
// mbragin: todo shouldn't it be more? or configurable?
|
// mbragin: todo shouldn't it be more? or configurable?
|
||||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
channel := make(UpdateChannel, channelBufferSize)
|
||||||
p.peerChannels[peerID] = channel
|
p.peerChannels.Store(peerID, channel)
|
||||||
|
p.len.Add(1)
|
||||||
|
|
||||||
log.Debugf("opened updates channel for a peer %s", peerID)
|
log.Debugf("opened updates channel for a peer %s", peerID)
|
||||||
return channel
|
return channel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PeersUpdateManager) closeChannel(peerID string) {
|
func (p *PeersUpdateManager) GetChannel(peerID string) UpdateChannel {
|
||||||
if channel, ok := p.peerChannels[peerID]; ok {
|
if ch, ok := p.peerChannels.Load(peerID); ok {
|
||||||
delete(p.peerChannels, peerID)
|
channel := ch.(UpdateChannel)
|
||||||
close(channel)
|
return channel
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("closed updates channel of a peer %s", peerID)
|
func (p *PeersUpdateManager) closeChannel(peerID string) {
|
||||||
|
if ch, ok := p.peerChannels.LoadAndDelete(peerID); ok {
|
||||||
|
channel, ok := ch.(UpdateChannel)
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("could not cast to UpdateChannel")
|
||||||
|
}
|
||||||
|
p.len.Add(-1)
|
||||||
|
close(channel)
|
||||||
|
log.Debugf("closed updates channel of a peer %s", peerID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseChannels closes updates channel for each given peer
|
// CloseChannels closes updates channel for each given peer
|
||||||
func (p *PeersUpdateManager) CloseChannels(peerIDs []string) {
|
func (p *PeersUpdateManager) CloseChannels(peerIDs []string) {
|
||||||
p.channelsMux.Lock()
|
|
||||||
defer p.channelsMux.Unlock()
|
|
||||||
for _, id := range peerIDs {
|
for _, id := range peerIDs {
|
||||||
p.closeChannel(id)
|
p.closeChannel(id)
|
||||||
}
|
}
|
||||||
@@ -81,18 +99,22 @@ func (p *PeersUpdateManager) CloseChannels(peerIDs []string) {
|
|||||||
|
|
||||||
// CloseChannel closes updates channel of a given peer
|
// CloseChannel closes updates channel of a given peer
|
||||||
func (p *PeersUpdateManager) CloseChannel(peerID string) {
|
func (p *PeersUpdateManager) CloseChannel(peerID string) {
|
||||||
p.channelsMux.Lock()
|
|
||||||
defer p.channelsMux.Unlock()
|
|
||||||
p.closeChannel(peerID)
|
p.closeChannel(peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllConnectedPeers returns a copy of the connected peers map
|
// GetAllConnectedPeers returns a copy of the connected peers map
|
||||||
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
|
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
|
||||||
p.channelsMux.Lock()
|
|
||||||
defer p.channelsMux.Unlock()
|
|
||||||
m := make(map[string]struct{})
|
m := make(map[string]struct{})
|
||||||
for ID := range p.peerChannels {
|
p.peerChannels.Range(func(key any, value any) bool {
|
||||||
m[ID] = struct{}{}
|
if ID, ok := key.(string); ok {
|
||||||
}
|
m[ID] = struct{}{}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Len returns the length of the peer channels
|
||||||
|
func (p *PeersUpdateManager) Len() (len int64) {
|
||||||
|
return p.len.Load()
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
//var peersUpdater *PeersUpdateManager
|
//var peersUpdater *PeersUpdateManager
|
||||||
@@ -13,10 +14,22 @@ func TestCreateChannel(t *testing.T) {
|
|||||||
peersUpdater := NewPeersUpdateManager()
|
peersUpdater := NewPeersUpdateManager()
|
||||||
defer peersUpdater.CloseChannel(peer)
|
defer peersUpdater.CloseChannel(peer)
|
||||||
|
|
||||||
|
if peersUpdater.Len() != 0 {
|
||||||
|
t.Error("peersUpdated should not have any channels yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ch := peersUpdater.GetChannel(peer); ch != nil {
|
||||||
|
t.Errorf("We should not have channel for %s yet", peer)
|
||||||
|
}
|
||||||
|
|
||||||
_ = peersUpdater.CreateChannel(peer)
|
_ = peersUpdater.CreateChannel(peer)
|
||||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
if ch := peersUpdater.GetChannel(peer); ch == nil {
|
||||||
t.Error("Error creating the channel")
|
t.Error("Error creating the channel")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peersUpdater.Len() != 1 {
|
||||||
|
t.Error("peersUpdated should have 1 channel")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSendUpdate(t *testing.T) {
|
func TestSendUpdate(t *testing.T) {
|
||||||
@@ -28,12 +41,12 @@ func TestSendUpdate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
_ = peersUpdater.CreateChannel(peer)
|
_ = peersUpdater.CreateChannel(peer)
|
||||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
if ch := peersUpdater.GetChannel(peer); ch == nil {
|
||||||
t.Error("Error creating the channel")
|
t.Error("Error creating the channel")
|
||||||
}
|
}
|
||||||
peersUpdater.SendUpdate(peer, update1)
|
peersUpdater.SendUpdate(peer, update1)
|
||||||
select {
|
select {
|
||||||
case <-peersUpdater.peerChannels[peer]:
|
case <-peersUpdater.GetChannel(peer):
|
||||||
default:
|
default:
|
||||||
t.Error("Update wasn't send")
|
t.Error("Update wasn't send")
|
||||||
}
|
}
|
||||||
@@ -54,7 +67,7 @@ func TestSendUpdate(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case <-timeout:
|
case <-timeout:
|
||||||
t.Error("timed out reading previously sent updates")
|
t.Error("timed out reading previously sent updates")
|
||||||
case updateReader := <-peersUpdater.peerChannels[peer]:
|
case updateReader := <-peersUpdater.GetChannel(peer):
|
||||||
if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial {
|
if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial {
|
||||||
t.Error("got the update that shouldn't have been sent")
|
t.Error("got the update that shouldn't have been sent")
|
||||||
}
|
}
|
||||||
@@ -67,11 +80,11 @@ func TestCloseChannel(t *testing.T) {
|
|||||||
peer := "test-close"
|
peer := "test-close"
|
||||||
peersUpdater := NewPeersUpdateManager()
|
peersUpdater := NewPeersUpdateManager()
|
||||||
_ = peersUpdater.CreateChannel(peer)
|
_ = peersUpdater.CreateChannel(peer)
|
||||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
if ch := peersUpdater.GetChannel(peer); ch == nil {
|
||||||
t.Error("Error creating the channel")
|
t.Error("Error creating the channel")
|
||||||
}
|
}
|
||||||
peersUpdater.CloseChannel(peer)
|
peersUpdater.CloseChannel(peer)
|
||||||
if _, ok := peersUpdater.peerChannels[peer]; ok {
|
if ch := peersUpdater.GetChannel(peer); ch != nil {
|
||||||
t.Error("Error closing the channel")
|
t.Error("Error closing the channel")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -377,7 +377,9 @@ func (am *DefaultAccountManager) deleteRegularUser(account *Account, initiatorUs
|
|||||||
meta := map[string]any{"name": tuName, "email": tuEmail}
|
meta := map[string]any{"name": tuName, "email": tuEmail}
|
||||||
am.storeEvent(initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
am.storeEvent(initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||||
|
|
||||||
return am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetUserID string, account *Account) error {
|
func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetUserID string, account *Account) error {
|
||||||
@@ -674,9 +676,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := am.updateAccountPeers(account); err != nil {
|
am.updateAccountPeers(account)
|
||||||
log.Errorf("failed updating account peers while updating user %s", accountID)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if err = am.Store.SaveAccount(account); err != nil {
|
if err = am.Store.SaveAccount(account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -870,9 +870,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []
|
|||||||
if len(peerIDs) != 0 {
|
if len(peerIDs) != 0 {
|
||||||
// this will trigger peer disconnect from the management service
|
// this will trigger peer disconnect from the management service
|
||||||
am.peersUpdateManager.CloseChannels(peerIDs)
|
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||||
if err := am.updateAccountPeers(account); err != nil {
|
am.updateAccountPeers(account)
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user