Compare commits

..

6 Commits

Author SHA1 Message Date
Yury Gargay
026958c22a Lock by PeerID when CreateChannel 2023-10-09 14:16:08 +02:00
Yury Gargay
9bfab103a0 Merge remote-tracking branch 'upstream/main' into yury/use-sync-map-in-updatechannel 2023-10-09 13:54:21 +02:00
Yury Gargay
dece311076 Merge remote-tracking branch 'upstream/main' into yury/use-sync-map-in-updatechannel 2023-10-04 14:03:34 +02:00
Yury Gargay
206d903de5 Update error message 2023-09-25 10:34:10 +02:00
Yury Gargay
3e20f23646 Rework Len() method via atomic counter 2023-09-25 10:34:10 +02:00
Yury Gargay
025fefc6bd Use sync.Map in PeersUpdateManager 2023-09-25 10:34:10 +02:00
21 changed files with 349 additions and 485 deletions

View File

@@ -1,5 +1,4 @@
#!/bin/bash
set -e
if ! which curl >/dev/null 2>&1; then
echo "This script uses curl fetch OpenID configuration from IDP."
@@ -155,8 +154,6 @@ if [ -n "$NETBIRD_MGMT_IDP" ]; then
export NETBIRD_IDP_MGMT_CLIENT_ID
export NETBIRD_IDP_MGMT_CLIENT_SECRET
export NETBIRD_IDP_MGMT_EXTRA_CONFIG=$EXTRA_CONFIG
else
export NETBIRD_IDP_MGMT_EXTRA_CONFIG={}
fi
IFS=',' read -r -a REDIRECT_URL_PORTS <<< "$NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS"
@@ -173,29 +170,8 @@ if [ "$NETBIRD_DASH_AUTH_USE_AUDIENCE" = "false" ]; then
export NETBIRD_AUTH_PKCE_AUDIENCE=
fi
# Read the encryption key
if test -f 'management.json'; then
encKey=$(jq -r ".DataStoreEncryptionKey" management.json)
if [[ "$encKey" != "null" ]]; then
export NETBIRD_DATASTORE_ENC_KEY=$encKey
fi
fi
env | grep NETBIRD
bkp_postfix="$(date +%s)"
if test -f 'docker-compose.yml'; then
cp docker-compose.yml "docker-compose.yml.bkp.${bkp_postfix}"
fi
if test -f 'management.json'; then
cp management.json "management.json.bkp.${bkp_postfix}"
fi
if test -f 'turnserver.conf'; then
cp turnserver.conf "turnserver.conf.bpk.${bkp_postfix}"
fi
envsubst <docker-compose.yml.tmpl >docker-compose.yml
envsubst <management.json.tmpl | jq . >management.json
envsubst <turnserver.conf.tmpl >turnserver.conf
envsubst <management.json.tmpl >management.json
envsubst <turnserver.conf.tmpl >turnserver.conf

View File

@@ -27,7 +27,6 @@
"Password": null
},
"Datadir": "",
"DataStoreEncryptionKey": "$NETBIRD_DATASTORE_ENC_KEY",
"HttpConfig": {
"Address": "0.0.0.0:$NETBIRD_MGMT_API_PORT",
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",

View File

@@ -103,7 +103,6 @@ type AccountManager interface {
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
}
type DefaultAccountManager struct {
@@ -180,7 +179,7 @@ type Account struct {
Policies []*Policy
Routes map[string]*route.Route
NameServerGroups map[string]*nbdns.NameServerGroup
DNSSettings DNSSettings
DNSSettings *DNSSettings
// Settings is a dictionary of Account settings
Settings *Settings
}
@@ -201,7 +200,7 @@ type UserInfo struct {
// from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) getRoutesToSync(peerID string, aclPeers []*Peer) []*route.Route {
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
routes, peerDisabledRoutes := a.getEnabledAndDisabledRoutesByPeer(peerID)
peerRoutesMembership := make(lookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{}
@@ -209,7 +208,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*Peer) []*route.Rout
groupListMap := a.getPeerGroups(peerID)
for _, peer := range aclPeers {
activeRoutes, _ := a.getRoutingPeerRoutes(peer.ID)
activeRoutes, _ := a.getEnabledAndDisabledRoutesByPeer(peer.ID)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
@@ -245,32 +244,20 @@ func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap looku
return filteredRoutes
}
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
// getEnabledAndDisabledRoutesByPeer returns the enabled and disabled lists of routes that belong to a peer.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
// If the given is not a routing peer, then the lists are empty.
func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
peer := a.GetPeer(peerID)
if peer == nil {
log.Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
return enabledRoutes, disabledRoutes
}
// currently we support only linux routing peers
if peer.Meta.GoOS != "linux" {
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[string]struct{})
func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Route, []*route.Route) {
var enabledRoutes []*route.Route
var disabledRoutes []*route.Route
takeRoute := func(r *route.Route, id string) {
if _, ok := seenRoute[r.ID]; ok {
peer := a.GetPeer(peerID)
if peer == nil {
log.Errorf("route %s has peer %s that doesn't exist under account %s", r.ID, peerID, a.Id)
return
}
seenRoute[r.ID] = struct{}{}
if r.Enabled {
r.Peer = peer.Key
enabledRoutes = append(enabledRoutes, r)
return
}
@@ -278,30 +265,25 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
}
for _, r := range a.Routes {
for _, groupID := range r.PeerGroups {
group := a.GetGroup(groupID)
if group == nil {
log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
continue
}
for _, id := range group.Peers {
if id != peerID {
if len(r.PeerGroups) != 0 {
for _, groupID := range r.PeerGroups {
group := a.GetGroup(groupID)
if group == nil {
log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
continue
}
newPeerRoute := r.Copy()
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map
takeRoute(newPeerRoute, id)
break
for _, id := range group.Peers {
if id == peerID {
takeRoute(r, id)
break
}
}
}
}
if r.Peer == peerID {
takeRoute(r.Copy(), peerID)
takeRoute(r, peerID)
}
}
return enabledRoutes, disabledRoutes
}
@@ -337,7 +319,50 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
peersToConnect = append(peersToConnect, p)
}
routesUpdate := a.getRoutesToSync(peerID, peersToConnect)
routes := a.getRoutesToSync(peerID, peersToConnect)
takePeer := func(id string) (*Peer, bool) {
peer := a.GetPeer(id)
if peer == nil || peer.Meta.GoOS != "linux" {
return nil, false
}
return peer, true
}
// We need to set Peer.Key instead of Peer.ID because this object will be sent to agents as part of a network map.
// Ideally we should have a separate field for that, but fine for now.
var routesUpdate []*route.Route
seenPeers := make(map[string]bool)
for _, r := range routes {
if r.Peer != "" {
peer, valid := takePeer(r.Peer)
if !valid {
continue
}
rCopy := r.Copy()
rCopy.Peer = peer.Key // client expects the key
routesUpdate = append(routesUpdate, rCopy)
continue
}
for _, groupID := range r.PeerGroups {
if group := a.GetGroup(groupID); group != nil {
for _, peerId := range group.Peers {
peer, valid := takePeer(peerId)
if !valid {
continue
}
if _, ok := seenPeers[peer.ID]; !ok {
rCopy := r.Copy()
rCopy.ID = r.ID + ":" + peer.ID // we have to provide unit route id when distribute network map
rCopy.Peer = peer.Key // client expects the key
routesUpdate = append(routesUpdate, rCopy)
}
seenPeers[peer.ID] = true
}
}
}
}
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
@@ -513,11 +538,13 @@ func (a *Account) getUserGroups(userID string) ([]string, error) {
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID)
enabled := true
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
_, found := peerGroups[groupID]
if found {
enabled = false
break
if a.DNSSettings != nil {
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
_, found := peerGroups[groupID]
if found {
enabled = false
break
}
}
}
return enabled
@@ -604,7 +631,10 @@ func (a *Account) Copy() *Account {
nsGroups[id] = nsGroup.Copy()
}
dnsSettings := a.DNSSettings.Copy()
var dnsSettings *DNSSettings
if a.DNSSettings != nil {
dnsSettings = a.DNSSettings.Copy()
}
var settings *Settings
if a.Settings != nil {
@@ -942,7 +972,6 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
if err != nil {
return err
}
log.Infof("%d entries received from IdP management", len(userData))
// If the Identity Provider does not support writing AppMetadata,
// in cases like this, we expect it to return all users in an "unset" field.
@@ -1042,7 +1071,6 @@ func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interf
if err != nil {
return nil, err
}
log.Debugf("%d entries received from IdP management", len(userData))
dataMap := make(map[string]*idp.UserData, len(userData))
for _, datum := range userData {
@@ -1554,11 +1582,6 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
}
}
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
return am.peersUpdateManager.GetAllConnectedPeers(), nil
}
func isDomainValid(domain string) bool {
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
return re.Match([]byte(domain))
@@ -1613,7 +1636,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewAdminUser(userID)
dnsSettings := DNSSettings{
dnsSettings := &DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
log.Debugf("created new account %s", accountID)

View File

@@ -1237,7 +1237,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
}
account := &Account{
Peers: map[string]*Peer{
"peer-1": {Key: "peer-1", Meta: PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: PeerSystemMeta{GoOS: "linux"}},
"peer-1": {Key: "peer-1"}, "peer-2": {Key: "peer-2"}, "peer-3": {Key: "peer-1"},
},
Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[string]*route.Route{
@@ -1374,7 +1374,7 @@ func TestAccount_Copy(t *testing.T) {
NameServers: []nbdns.NameServer{},
},
},
DNSSettings: DNSSettings{DisabledManagementGroups: []string{}},
DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}},
Settings: &Settings{},
}
err := hasNilField(account)

View File

@@ -45,9 +45,6 @@ const (
"VALUES(?, ?, ?, ?, ?, ?)"
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com"
)
// Store is the implementation of the activity.Store interface backed by SQLite
@@ -131,7 +128,6 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
events := make([]*activity.Event, 0)
var cryptErr error
for result.Next() {
var id int64
var operation activity.Activity
@@ -160,8 +156,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
if targetUserName != nil {
name, err := store.fieldEncrypt.Decrypt(*targetUserName)
if err != nil {
cryptErr = fmt.Errorf("failed to decrypt username for target id: %s", target)
meta["username"] = fallbackName
log.Errorf("failed to decrypt username for target id: %s", target)
meta["username"] = ""
} else {
meta["username"] = name
}
@@ -170,8 +166,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
if targetEmail != nil {
email, err := store.fieldEncrypt.Decrypt(*targetEmail)
if err != nil {
cryptErr = fmt.Errorf("failed to decrypt email address for target id: %s", target)
meta["email"] = fallbackEmail
log.Errorf("failed to decrypt email address for target id: %s", target)
meta["email"] = ""
} else {
meta["email"] = email
}
@@ -190,8 +186,7 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
if initiatorName != nil {
name, err := store.fieldEncrypt.Decrypt(*initiatorName)
if err != nil {
cryptErr = fmt.Errorf("failed to decrypt username of initiator: %s", initiator)
event.InitiatorName = fallbackName
log.Errorf("failed to decrypt username of initiator: %s", initiator)
} else {
event.InitiatorName = name
}
@@ -200,8 +195,7 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
if initiatorEmail != nil {
email, err := store.fieldEncrypt.Decrypt(*initiatorEmail)
if err != nil {
cryptErr = fmt.Errorf("failed to decrypt email address of initiator: %s", initiator)
event.InitiatorEmail = fallbackEmail
log.Errorf("failed to decrypt email address of initiator: %s", initiator)
} else {
event.InitiatorEmail = email
}
@@ -210,10 +204,6 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
events = append(events, event)
}
if cryptErr != nil {
log.Warnf("%s", cryptErr)
}
return events, nil
}

View File

@@ -24,11 +24,19 @@ type DNSSettings struct {
}
// Copy returns a copy of the DNS settings
func (d DNSSettings) Copy() DNSSettings {
settings := DNSSettings{
DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)),
func (d *DNSSettings) Copy() *DNSSettings {
settings := &DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
copy(settings.DisabledManagementGroups, d.DisabledManagementGroups)
if d == nil {
return settings
}
if d.DisabledManagementGroups != nil && len(d.DisabledManagementGroups) > 0 {
settings.DisabledManagementGroups = d.DisabledManagementGroups[:]
}
return settings
}
@@ -50,8 +58,12 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
if !user.IsAdmin() {
return nil, status.Errorf(status.PermissionDenied, "only admins are allowed to view DNS settings")
}
dnsSettings := account.DNSSettings.Copy()
return &dnsSettings, nil
if account.DNSSettings == nil {
return &DNSSettings{}, nil
}
return account.DNSSettings.Copy(), nil
}
// SaveDNSSettings validates a user role and updates the account's DNS settings
@@ -84,7 +96,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
}
}
oldSettings := account.DNSSettings.Copy()
oldSettings := &DNSSettings{}
if account.DNSSettings != nil {
oldSettings = account.DNSSettings.Copy()
}
account.DNSSettings = dnsSettingsToSave.Copy()
account.Network.IncSerial()

View File

@@ -42,7 +42,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("DNS settings for new accounts shouldn't return nil")
}
account.DNSSettings = DNSSettings{
account.DNSSettings = &DNSSettings{
DisabledManagementGroups: []string{group1ID},
}

View File

@@ -111,6 +111,10 @@ func restore(file string) (*FileStore, error) {
for _, peer := range account.Peers {
store.PeerKeyID2AccountID[peer.Key] = accountID
store.PeerID2AccountID[peer.ID] = accountID
// reset all peers to status = Disconnected
if peer.Status != nil && peer.Status.Connected {
peer.Status.Connected = false
}
}
for _, user := range account.Users {
store.UserID2AccountID[user.Id] = accountID

View File

@@ -61,7 +61,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels))
return peersUpdateManager.Len()
})
if err != nil {
return nil, err

View File

@@ -26,7 +26,7 @@ const (
testDNSSettingsUserID = "test_user"
)
var baseExistingDNSSettings = server.DNSSettings{
var baseExistingDNSSettings = &server.DNSSettings{
DisabledManagementGroups: []string{testDNSSettingsExistingGroup},
}
@@ -43,7 +43,7 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
return &DNSSettingsHandler{
accountManager: &mock_server.MockAccountManager{
GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) {
return &testingDNSSettingsAccount.DNSSettings, nil
return testingDNSSettingsAccount.DNSSettings, nil
},
SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
if dnsSettingsToSave != nil {

View File

@@ -31,24 +31,6 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee
}
}
func (h *PeersHandler) checkPeerStatus(peer *server.Peer) (*server.Peer, error) {
peerToReturn := peer.Copy()
if peer.Status.Connected {
statuses, err := h.accountManager.GetAllConnectedPeers()
if err != nil {
return peerToReturn, err
}
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
// This may happen after server restart when not all peers are yet connected
if _, connected := statuses[peerToReturn.ID]; !connected {
peerToReturn.Status.Connected = false
}
}
return peerToReturn, nil
}
func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(account.Id, peerID, userID)
if err != nil {
@@ -56,13 +38,7 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w
return
}
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toPeerResponse(peerToReturn, account, h.accountManager.GetDNSDomain()))
util.WriteJSONObject(w, toPeerResponse(peer, account, h.accountManager.GetDNSDomain()))
}
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -144,12 +120,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
respBody := []*api.Peer{}
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
return
}
respBody = append(respBody, toPeerResponse(peerToReturn, account, dnsDomain))
respBody = append(respBody, toPeerResponse(peer, account, dnsDomain))
}
util.WriteJSONObject(w, respBody)
return

View File

@@ -3,7 +3,6 @@ package http
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
@@ -24,33 +23,19 @@ import (
)
const testPeerID = "test_peer"
const noUpdateChannelTestPeerID = "no-update-channel"
func initTestMetaData(peers ...*server.Peer) *PeersHandler {
return &PeersHandler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) {
var p *server.Peer
for _, peer := range peers {
if update.ID == peer.ID {
p = peer.Copy()
break
}
}
p := peers[0].Copy()
p.SSHEnabled = update.SSHEnabled
p.LoginExpirationEnabled = update.LoginExpirationEnabled
p.Name = update.Name
return p, nil
},
GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) {
var p *server.Peer
for _, peer := range peers {
if peerID == peer.ID {
p = peer.Copy()
break
}
}
return p, nil
return peers[0], nil
},
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
@@ -72,16 +57,6 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler {
},
}, user, nil
},
GetAllConnectedPeersFunc: func() (map[string]struct{}, error) {
statuses := make(map[string]struct{})
for _, peer := range peers {
if peer.ID == noUpdateChannelTestPeerID {
break
}
statuses[peer.ID] = struct{}{}
}
return statuses, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
@@ -104,7 +79,7 @@ func TestGetPeers(t *testing.T) {
Key: "key",
SetupKey: "setupkey",
IP: net.ParseIP("100.64.0.1"),
Status: &server.PeerStatus{Connected: true},
Status: &server.PeerStatus{},
Name: "PeerName",
LoginExpirationEnabled: false,
Meta: server.PeerSystemMeta{
@@ -118,17 +93,11 @@ func TestGetPeers(t *testing.T) {
},
}
peer1 := peer.Copy()
peer1.ID = noUpdateChannelTestPeerID
expectedUpdatedPeer := peer.Copy()
expectedUpdatedPeer.LoginExpirationEnabled = true
expectedUpdatedPeer.SSHEnabled = true
expectedUpdatedPeer.Name = "New Name"
expectedPeer1 := peer1.Copy()
expectedPeer1.Status.Connected = false
tt := []struct {
name string
expectedStatus int
@@ -147,21 +116,13 @@ func TestGetPeers(t *testing.T) {
expectedPeer: peer,
},
{
name: "GetPeer with update channel",
name: "GetPeer",
requestType: http.MethodGet,
requestPath: "/api/peers/" + testPeerID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: peer,
},
{
name: "GetPeer with no update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + peer1.ID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: expectedPeer1,
},
{
name: "PutPeer",
requestType: http.MethodPut,
@@ -175,7 +136,7 @@ func TestGetPeers(t *testing.T) {
rr := httptest.NewRecorder()
p := initTestMetaData(peer, peer1)
p := initTestMetaData(peer)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -210,10 +171,6 @@ func TestGetPeers(t *testing.T) {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2)
assert.Equal(t, respBody[1].Connected, false)
got = respBody[0]
} else {
got = &api.Peer{}
@@ -223,15 +180,12 @@ func TestGetPeers(t *testing.T) {
}
}
fmt.Println(got)
assert.Equal(t, got.Name, tc.expectedPeer.Name)
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
assert.Equal(t, got.Os, "OS core")
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
})
}
}

View File

@@ -251,18 +251,34 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
// GetAccount returns all the users for a given profile.
func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := am.getAllUsers()
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
}
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
}
users := make([]*UserData, 0)
for _, user := range userList.Results {
userData := parseAuthentikUser(user)
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
@@ -271,57 +287,35 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := am.getAllUsers()
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
return indexedUsers, nil
}
// getAllUsers returns all users in a Authentik account.
func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
page := int32(1)
for {
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute()
if err != nil {
return nil, err
}
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
for _, user := range userList.Results {
users = append(users, parseAuthentikUser(user))
}
page = int32(userList.GetPagination().Next)
if userList.GetPagination().Next == 0 {
break
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
return users, nil
indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results {
userData := parseAuthentikUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return indexedUsers, nil
}
// CreateUser creates a new user in authentik Idp and sends an invitation.

View File

@@ -266,7 +266,10 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
// GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := am.getAllUsers()
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get("users", q)
if err != nil {
return nil, err
}
@@ -275,9 +278,18 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
am.appMetrics.IDPMetrics().CountGetAccount()
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
var profiles struct{ Value []azureProfile }
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Value {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
@@ -286,16 +298,28 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := am.getAllUsers()
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get("users", q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
var profiles struct{ Value []azureProfile }
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
for _, profile := range profiles.Value {
userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return indexedUsers, nil
@@ -349,39 +373,6 @@ func (am *AzureManager) DeleteUser(userID string) error {
return nil
}
// getAllUsers returns all users in an Azure AD account.
func (am *AzureManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
q := url.Values{}
q.Add("$select", profileFields)
q.Add("$top", "500")
for nextLink := "users"; nextLink != ""; {
body, err := am.get(nextLink, q)
if err != nil {
return nil, err
}
var profiles struct {
Value []azureProfile
NextLink string `json:"@odata.nextLink"`
}
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
for _, profile := range profiles.Value {
users = append(users, profile.userData())
}
nextLink = profiles.NextLink
}
return users, nil
}
// get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
@@ -389,14 +380,7 @@ func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
return nil, err
}
var reqURL string
if strings.HasPrefix(resource, "https") {
// Already an absolute URL for paging
reqURL = resource
} else {
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
}
reqURL := fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err

View File

@@ -96,7 +96,7 @@ func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata)
// GetUserDataByID requests user data from Google Workspace via ID.
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
user, err := gm.usersService.Get(userID).Do()
user, err := gm.usersService.Get(userID).Projection("full").Do()
if err != nil {
return nil, err
}
@@ -113,67 +113,41 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
// GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := gm.getAllUsers()
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetAccount()
usersData := make([]*UserData, 0)
for _, user := range usersList.Users {
userData := parseGoogleWorkspaceUser(user)
userData.AppMetadata.WTAccountID = accountID
usersData = append(usersData, userData)
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}
return users, nil
return usersData, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := gm.getAllUsers()
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
return indexedUsers, nil
}
// getAllUsers returns all users in a Google Workspace account filtered by customer ID.
func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
pageToken := ""
for {
call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500)
if pageToken != "" {
call.PageToken(pageToken)
}
resp, err := call.Do()
if err != nil {
return nil, err
}
for _, user := range resp.Users {
users = append(users, parseGoogleWorkspaceUser(user))
}
pageToken = resp.NextPageToken
if pageToken == "" {
break
}
indexedUsers := make(map[string][]*UserData)
for _, user := range usersList.Users {
userData := parseGoogleWorkspaceUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return users, nil
return indexedUsers, nil
}
// CreateUser creates a new user in Google Workspace and sends an invitation.
@@ -184,7 +158,7 @@ func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, erro
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) {
user, err := gm.usersService.Get(email).Do()
user, err := gm.usersService.Get(email).Projection("full").Do()
if err != nil {
return nil, err
}

View File

@@ -9,7 +9,6 @@ import (
"time"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/okta/query"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -161,7 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
// GetAccount returns all the users for a given profile.
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := om.getAllUsers()
users, resp, err := om.client.User.ListUsers(context.Background(), nil)
if err != nil {
return nil, err
}
@@ -170,40 +169,39 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
om.appMetrics.IDPMetrics().CountGetAccount()
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get account, statusCode %d", resp.StatusCode)
}
return users, nil
list := make([]*UserData, 0)
for _, user := range users {
userData, err := parseOktaUser(user)
if err != nil {
return nil, err
}
userData.AppMetadata.WTAccountID = accountID
list = append(list, userData)
}
return list, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := om.getAllUsers()
users, resp, err := om.client.User.ListUsers(context.Background(), nil)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountGetAllAccounts()
}
return indexedUsers, nil
}
// getAllUsers returns all users in an Okta account.
func (om *OktaManager) getAllUsers() ([]*UserData, error) {
qp := query.NewQueryParams(query.WithLimit(200))
userList, resp, err := om.client.User.ListUsers(context.Background(), qp)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
@@ -211,34 +209,17 @@ func (om *OktaManager) getAllUsers() ([]*UserData, error) {
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
for resp.HasNextPage() {
paginatedUsers := make([]*okta.User, 0)
resp, err = resp.Next(context.Background(), &paginatedUsers)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
userList = append(userList, paginatedUsers...)
}
users := make([]*UserData, 0, len(userList))
for _, user := range userList {
indexedUsers := make(map[string][]*UserData)
for _, user := range users {
userData, err := parseOktaUser(user)
if err != nil {
return nil, err
}
users = append(users, userData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return users, nil
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.

View File

@@ -75,7 +75,6 @@ type MockAccountManager struct {
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@@ -584,11 +583,3 @@ func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *ser
}
return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
}
// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface
func (am *MockAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
if am.GetAllConnectedPeersFunc != nil {
return am.GetAllConnectedPeersFunc()
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented")
}

View File

@@ -98,7 +98,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id)
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}

View File

@@ -5,7 +5,6 @@ import (
"testing"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity"
@@ -49,12 +48,11 @@ func TestCreateRoute(t *testing.T) {
}
testCases := []struct {
name string
inputArgs input
createInitRoute bool
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedRoute *route.Route
name string
inputArgs input
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedRoute *route.Route
}{
{
name: "Happy Path",
@@ -166,9 +164,8 @@ func TestCreateRoute(t *testing.T) {
enabled: true,
groups: []string{routeGroup1},
},
createInitRoute: true,
errFunc: require.Error,
shouldCreate: false,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Bad Peers Group already has this route",
@@ -182,9 +179,8 @@ func TestCreateRoute(t *testing.T) {
enabled: true,
groups: []string{routeGroup1},
},
createInitRoute: true,
errFunc: require.Error,
shouldCreate: false,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Empty Peer Should Create",
@@ -330,18 +326,6 @@ func TestCreateRoute(t *testing.T) {
t.Errorf("failed to init testing account: %s", err)
}
if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll()
if errInit != nil {
t.Errorf("failed to get group all: %s", errInit)
}
_, errInit = am.CreateRoute(account.Id, existingNetwork, "", []string{routeGroup3, routeGroup4},
"", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID)
if errInit != nil {
t.Errorf("failed to create init route: %s", errInit)
}
}
outRoute, err := am.CreateRoute(
account.Id,
testCase.inputArgs.network,
@@ -386,18 +370,17 @@ func TestSaveRoute(t *testing.T) {
validGroupHA2 := routeGroupHA2
testCases := []struct {
name string
existingRoute *route.Route
createInitRoute bool
newPeer *string
newPeerGroups []string
newMetric *int
newPrefix *netip.Prefix
newGroups []string
skipCopying bool
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedRoute *route.Route
name string
existingRoute *route.Route
newPeer *string
newPeerGroups []string
newMetric *int
newPrefix *netip.Prefix
newGroups []string
skipCopying bool
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedRoute *route.Route
}{
{
name: "Happy Path",
@@ -662,9 +645,8 @@ func TestSaveRoute(t *testing.T) {
Enabled: true,
Groups: []string{routeGroup1},
},
createInitRoute: true,
newPeer: &validUsedPeer,
errFunc: require.Error,
newPeer: &validUsedPeer,
errFunc: require.Error,
},
{
name: "Do not allow to modify existing route with a peers group from another route",
@@ -680,9 +662,8 @@ func TestSaveRoute(t *testing.T) {
Enabled: true,
Groups: []string{routeGroup1},
},
createInitRoute: true,
newPeerGroups: []string{routeGroup4},
errFunc: require.Error,
newPeerGroups: []string{routeGroup4},
errFunc: require.Error,
},
}
for _, testCase := range testCases {
@@ -697,21 +678,6 @@ func TestSaveRoute(t *testing.T) {
t.Error("failed to init testing account")
}
if testCase.createInitRoute {
account.Routes["initRoute"] = &route.Route{
ID: "initRoute",
Network: netip.MustParsePrefix(existingNetwork),
NetID: existingRouteID,
NetworkType: route.IPv4Network,
PeerGroups: []string{routeGroup4},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
}
}
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute
err = am.Store.SaveAccount(account)
@@ -845,15 +811,15 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
peer1Routes, err := am.GetNetworkMap(peer1ID)
require.NoError(t, err)
assert.Len(t, peer1Routes.Routes, 1, "HA route should have 1 server route")
require.Len(t, peer1Routes.Routes, 3, "HA route should have more than 1 routes")
peer2Routes, err := am.GetNetworkMap(peer2ID)
require.NoError(t, err)
assert.Len(t, peer2Routes.Routes, 1, "HA route should have 1 server route")
require.Len(t, peer2Routes.Routes, 3, "HA route should have more than 1 routes")
peer4Routes, err := am.GetNetworkMap(peer4ID)
require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
require.Len(t, peer4Routes.Routes, 3, "HA route should have more than 1 routes")
groups, err := am.ListGroups(account.Id)
require.NoError(t, err)
@@ -872,32 +838,32 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
peer2RoutesAfterDelete, err := am.GetNetworkMap(peer2ID)
require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes")
require.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have only 2 route")
err = am.GroupDeletePeer(account.Id, groupHA2.ID, peer4ID)
require.NoError(t, err)
peer2RoutesAfterDelete, err = am.GetNetworkMap(peer2ID)
require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
require.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
err = am.GroupAddPeer(account.Id, groupHA2.ID, peer4ID)
require.NoError(t, err)
peer1RoutesAfterAdd, err := am.GetNetworkMap(peer1ID)
require.NoError(t, err)
assert.Len(t, peer1RoutesAfterAdd.Routes, 1, "HA route should have more than 1 route")
require.Len(t, peer1RoutesAfterAdd.Routes, 2, "HA route should have more than 1 route")
peer2RoutesAfterAdd, err := am.GetNetworkMap(peer2ID)
require.NoError(t, err)
assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes")
require.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have more than 1 route")
err = am.DeleteRoute(account.Id, newRoute.ID, userID)
require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(peer1ID)
require.NoError(t, err)
assert.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1")
require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1")
}
func TestGetNetworkMap_RouteSync(t *testing.T) {
@@ -1227,5 +1193,11 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
}
_, err = am.CreateRoute(account.Id, existingNetwork, "", []string{routeGroup3, routeGroup4},
"", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID)
if err != nil {
return nil, err
}
return am.Store.GetAccount(account.Id)
}

View File

@@ -2,6 +2,7 @@ package server
import (
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
@@ -14,25 +15,29 @@ type UpdateMessage struct {
Update *proto.SyncResponse
}
type UpdateChannel chan *UpdateMessage
type PeersUpdateManager struct {
// peerChannels is an update channel indexed by Peer.ID
peerChannels map[string]chan *UpdateMessage
channelsMux *sync.Mutex
peerChannels sync.Map
// peerChannelLocks keeps the peer locks to organize channel creations
peerChannelLocks sync.Map
// len is the length of peerChannels
len atomic.Int64
}
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager() *PeersUpdateManager {
return &PeersUpdateManager{
peerChannels: make(map[string]chan *UpdateMessage),
channelsMux: &sync.Mutex{},
}
return &PeersUpdateManager{}
}
// SendUpdate sends update message to the peer's channel
func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) {
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
if channel, ok := p.peerChannels[peerID]; ok {
if ch, ok := p.peerChannels.Load(peerID); ok {
channel, ok := ch.(UpdateChannel)
if !ok {
log.Warnf("could not cast to UpdateChannel")
}
select {
case channel <- update:
log.Debugf("update was sent to channel for peer %s", peerID)
@@ -45,35 +50,48 @@ func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) {
}
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage {
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
func (p *PeersUpdateManager) CreateChannel(peerID string) UpdateChannel {
// we have to lock the whole operation by peerID as we do two non atomic operations:
// - closeChannel()
// - Store
value, _ := p.peerChannelLocks.LoadOrStore(peerID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
defer mtx.Unlock()
p.closeChannel(peerID)
if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID)
close(channel)
}
// mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize)
p.peerChannels[peerID] = channel
channel := make(UpdateChannel, channelBufferSize)
p.peerChannels.Store(peerID, channel)
p.len.Add(1)
log.Debugf("opened updates channel for a peer %s", peerID)
return channel
}
func (p *PeersUpdateManager) closeChannel(peerID string) {
if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID)
close(channel)
func (p *PeersUpdateManager) GetChannel(peerID string) UpdateChannel {
if ch, ok := p.peerChannels.Load(peerID); ok {
channel := ch.(UpdateChannel)
return channel
}
return nil
}
log.Debugf("closed updates channel of a peer %s", peerID)
func (p *PeersUpdateManager) closeChannel(peerID string) {
if ch, ok := p.peerChannels.LoadAndDelete(peerID); ok {
channel, ok := ch.(UpdateChannel)
if !ok {
log.Errorf("could not cast to UpdateChannel")
}
p.len.Add(-1)
close(channel)
log.Debugf("closed updates channel of a peer %s", peerID)
}
}
// CloseChannels closes updates channel for each given peer
func (p *PeersUpdateManager) CloseChannels(peerIDs []string) {
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
for _, id := range peerIDs {
p.closeChannel(id)
}
@@ -81,18 +99,22 @@ func (p *PeersUpdateManager) CloseChannels(peerIDs []string) {
// CloseChannel closes updates channel of a given peer
func (p *PeersUpdateManager) CloseChannel(peerID string) {
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
p.closeChannel(peerID)
}
// GetAllConnectedPeers returns a copy of the connected peers map
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
m := make(map[string]struct{})
for ID := range p.peerChannels {
m[ID] = struct{}{}
}
p.peerChannels.Range(func(key any, value any) bool {
if ID, ok := key.(string); ok {
m[ID] = struct{}{}
}
return true
})
return m
}
// Len returns the length of the peer channels
func (p *PeersUpdateManager) Len() (len int64) {
return p.len.Load()
}

View File

@@ -1,9 +1,10 @@
package server
import (
"github.com/netbirdio/netbird/management/proto"
"testing"
"time"
"github.com/netbirdio/netbird/management/proto"
)
//var peersUpdater *PeersUpdateManager
@@ -13,10 +14,22 @@ func TestCreateChannel(t *testing.T) {
peersUpdater := NewPeersUpdateManager()
defer peersUpdater.CloseChannel(peer)
if peersUpdater.Len() != 0 {
t.Error("peersUpdated should not have any channels yet")
}
if ch := peersUpdater.GetChannel(peer); ch != nil {
t.Errorf("We should not have channel for %s yet", peer)
}
_ = peersUpdater.CreateChannel(peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
if ch := peersUpdater.GetChannel(peer); ch == nil {
t.Error("Error creating the channel")
}
if peersUpdater.Len() != 1 {
t.Error("peersUpdated should have 1 channel")
}
}
func TestSendUpdate(t *testing.T) {
@@ -28,12 +41,12 @@ func TestSendUpdate(t *testing.T) {
},
}}
_ = peersUpdater.CreateChannel(peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
if ch := peersUpdater.GetChannel(peer); ch == nil {
t.Error("Error creating the channel")
}
peersUpdater.SendUpdate(peer, update1)
select {
case <-peersUpdater.peerChannels[peer]:
case <-peersUpdater.GetChannel(peer):
default:
t.Error("Update wasn't send")
}
@@ -54,7 +67,7 @@ func TestSendUpdate(t *testing.T) {
select {
case <-timeout:
t.Error("timed out reading previously sent updates")
case updateReader := <-peersUpdater.peerChannels[peer]:
case updateReader := <-peersUpdater.GetChannel(peer):
if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial {
t.Error("got the update that shouldn't have been sent")
}
@@ -67,11 +80,11 @@ func TestCloseChannel(t *testing.T) {
peer := "test-close"
peersUpdater := NewPeersUpdateManager()
_ = peersUpdater.CreateChannel(peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
if ch := peersUpdater.GetChannel(peer); ch == nil {
t.Error("Error creating the channel")
}
peersUpdater.CloseChannel(peer)
if _, ok := peersUpdater.peerChannels[peer]; ok {
if ch := peersUpdater.GetChannel(peer); ch != nil {
t.Error("Error closing the channel")
}
}