diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 2cfc93415..2f92e1c03 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -13,6 +13,7 @@ import ( "google.golang.org/grpc" + "github.com/netbirdio/management-integrations/integrations" clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" mgmtProto "github.com/netbirdio/netbird/management/proto" @@ -78,7 +79,8 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste if err != nil { return nil, nil } - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + iv, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 952b3c90c..309b2e7c6 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -1050,7 +1051,8 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 7f8310c90..4e4a09145 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "github.com/netbirdio/management-integrations/integrations" "net" "testing" "time" @@ -114,7 +115,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index 67ec9c42e..5566f8559 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.9 github.com/google/gopacket v1.1.19 + github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 @@ -59,8 +60,7 @@ require ( github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 + github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index c36b8aff3..6da405341 100644 --- a/go.sum +++ b/go.sum @@ -255,6 +255,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= @@ -382,10 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4= -github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= diff --git a/management/client/client_test.go b/management/client/client_test.go index f30ae0cfd..30f91c73b 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -3,6 +3,7 @@ package client import ( "context" "net" + "os" "path/filepath" "sync" "testing" @@ -15,6 +16,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" mgmt "github.com/netbirdio/netbird/management/server" @@ -30,6 +32,12 @@ import ( const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" +func TestMain(m *testing.M) { + _ = util.InitLog("debug", "console") + code := m.Run() + os.Exit(code) +} + func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Helper() level, _ := log.ParseLevel("debug") @@ -60,7 +68,8 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index e8bcdc97d..23d9c195c 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" @@ -172,8 +173,12 @@ var ( log.Infof("geo location service has been initialized from %s", config.Datadir) } + integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore) + if err != nil { + return fmt.Errorf("failed to initialize integrated peer validator: %v", err) + } accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, - dnsDomain, eventStore, geo, userDeleteFromIDPEnabled) + dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) } @@ -323,6 +328,7 @@ var ( SetupCloseHandler() <-stopCh + integratedPeerValidator.Stop() if geo != nil { _ = geo.Stop() } diff --git a/management/server/account.go b/management/server/account.go index d9030007d..c145c1bd7 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -21,14 +21,15 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/management-integrations/additions" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integrated_validator" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" @@ -85,12 +86,12 @@ type AccountManager interface { GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) - GetGroup(accountId, groupID, userID string) (*Group, error) - GetAllGroups(accountID, userID string) ([]*Group, error) - GetGroupByName(groupName, accountID string) (*Group, error) - SaveGroup(accountID, userID string, group *Group) error + GetGroup(accountId, groupID, userID string) (*nbgroup.Group, error) + GetAllGroups(accountID, userID string) ([]*nbgroup.Group, error) + GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) + SaveGroup(accountID, userID string, group *nbgroup.Group) error DeleteGroup(accountId, userId, groupID string) error - ListGroups(accountId string) ([]*Group, error) + ListGroups(accountId string) ([]*nbgroup.Group, error) GroupAddPeer(accountId, groupID, peerID string) error GroupDeletePeer(accountId, groupID, peerID string) error GetPolicy(accountID, policyID, userID string) (*Policy, error) @@ -124,6 +125,9 @@ type AccountManager interface { DeletePostureChecks(accountID, postureChecksID, userID string) error ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager + UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error + GroupValidation(accountId string, groups []string) (bool, error) + GetValidatedPeers(account *Account) (map[string]struct{}, error) } type DefaultAccountManager struct { @@ -152,6 +156,8 @@ type DefaultAccountManager struct { // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool + + integratedPeerValidator integrated_validator.IntegratedValidator } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -218,8 +224,8 @@ type Account struct { PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` - Groups map[string]*Group `gorm:"-"` - GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*nbgroup.Group `gorm:"-"` + GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[string]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` @@ -247,7 +253,7 @@ type UserInfo struct { NonDeletable bool `json:"non_deletable"` LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` - IntegrationReference IntegrationReference `json:"-"` + IntegrationReference integration_reference.IntegrationReference `json:"-"` Permissions UserPermissions `json:"permissions"` } @@ -372,25 +378,26 @@ func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route { } // GetGroup returns a group by ID if exists, nil otherwise -func (a *Account) GetGroup(groupID string) *Group { +func (a *Account) GetGroup(groupID string) *nbgroup.Group { return a.Groups[groupID] } // GetPeerNetworkMap returns a group by ID if exists, nil otherwise -func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { +func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ Network: a.Network.Copy(), } } - validatedPeers := additions.ValidatePeers([]*nbpeer.Peer{peer}) - if len(validatedPeers) == 0 { + + if _, ok := validatedPeersMap[peerID]; !ok { return &NetworkMap{ Network: a.Network.Copy(), } } - aclPeers, firewallRules := a.getPeerConnectionResources(peerID) + + aclPeers, firewallRules := a.getPeerConnectionResources(peerID, validatedPeersMap) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -564,7 +571,7 @@ func (a *Account) FindUser(userID string) (*User, error) { } // FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. -func (a *Account) FindGroupByName(groupName string) (*Group, error) { +func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { for _, group := range a.Groups { if group.Name == groupName { return group, nil @@ -583,6 +590,20 @@ func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { return key, nil } +// GetPeerGroupsList return with the list of groups ID. +func (a *Account) GetPeerGroupsList(peerID string) []string { + var grps []string + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + grps = append(grps, groupID) + break + } + } + } + return grps +} + func (a *Account) getUserGroups(userID string) ([]string, error) { user, err := a.FindUser(userID) if err != nil { @@ -660,7 +681,7 @@ func (a *Account) Copy() *Account { setupKeys[id] = key.Copy() } - groups := map[string]*Group{} + groups := map[string]*nbgroup.Group{} for id, group := range a.Groups { groups[id] = group.Copy() } @@ -713,7 +734,7 @@ func (a *Account) Copy() *Account { } } -func (a *Account) GetGroupAll() (*Group, error) { +func (a *Account) GetGroupAll() (*nbgroup.Group, error) { for _, g := range a.Groups { if g.Name == "All" { return g, nil @@ -734,7 +755,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { return false } - existedGroupsByName := make(map[string]*Group) + existedGroupsByName := make(map[string]*nbgroup.Group) for _, group := range a.Groups { existedGroupsByName[group.Name] = group } @@ -743,7 +764,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { removed := 0 jwtAutoGroups := make(map[string]struct{}) for i, id := range user.AutoGroups { - if group, ok := a.Groups[id]; ok && group.Issued == GroupIssuedJWT { + if group, ok := a.Groups[id]; ok && group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = struct{}{} user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...) removed++ @@ -756,15 +777,15 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { for _, name := range groupsNames { group, ok := existedGroupsByName[name] if !ok { - group = &Group{ + group = &nbgroup.Group{ ID: xid.New().String(), Name: name, - Issued: GroupIssuedJWT, + Issued: nbgroup.GroupIssuedJWT, } a.Groups[group.ID] = group } // only JWT groups will be synced - if group.Issued == GroupIssuedJWT { + if group.Issued == nbgroup.GroupIssuedJWT { user.AutoGroups = append(user.AutoGroups, group.ID) if _, ok := jwtAutoGroups[name]; !ok { modified = true @@ -837,6 +858,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, userDeleteFromIDPEnabled bool, + integratedPeerValidator integrated_validator.IntegratedValidator, ) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ Store: store, @@ -850,6 +872,7 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, + integratedPeerValidator: integratedPeerValidator, } allAccounts := store.GetAllAccounts() // enable single account mode only if configured by user and number of existing accounts is not grater than 1 @@ -906,6 +929,8 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage }() } + am.integratedPeerValidator.SetPeerInvalidationListener(am.onPeersInvalidated) + return am, nil } @@ -948,7 +973,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") } - err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore) + err = am.integratedPeerValidator.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) if err != nil { return nil, err } @@ -1823,18 +1848,27 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut return nil } +func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { + updatedAccount, err := am.Store.GetAccount(accountID) + if err != nil { + log.Errorf("failed to get account %s: %v", accountID, err) + return + } + am.updateAccountPeers(updatedAccount) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { - allGroup := &Group{ + allGroup := &nbgroup.Group{ ID: xid.New().String(), Name: "All", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*Group{allGroup.ID: allGroup} + account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} id := xid.New().String() diff --git a/management/server/account/account.go b/management/server/account/account.go index b8b71a6de..40f032fbe 100644 --- a/management/server/account/account.go +++ b/management/server/account/account.go @@ -3,11 +3,17 @@ package account type ExtraSettings struct { // PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator PeerApprovalEnabled bool + + // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations + IntegratedValidatorGroups []string `gorm:"serializer:json"` } // Copy copies the ExtraSettings struct func (e *ExtraSettings) Copy() *ExtraSettings { + var cpGroup []string + return &ExtraSettings{ - PeerApprovalEnabled: e.PeerApprovalEnabled, + PeerApprovalEnabled: e.PeerApprovalEnabled, + IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...), } } diff --git a/management/server/account_test.go b/management/server/account_test.go index 2b0c44196..a0eff239b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -12,20 +12,57 @@ import ( "time" "github.com/golang-jwt/jwt" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/route" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/jwtclaims" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/route" ) +type MocIntegratedValidator struct { +} + +func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { + return update, nil +} +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, peer := range peers { + validatedPeers[peer.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) { + return false, false +} + +func (MocIntegratedValidator) PeerDeleted(_, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + +} + +func (MocIntegratedValidator) Stop() { +} + func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { t.Helper() peer := &nbpeer.Peer{ @@ -367,7 +404,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { account.Groups[all.ID].Peers = append(account.Groups[all.ID].Peers, peer.ID) } - networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io") + validatedPeers := map[string]struct{}{} + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -667,7 +709,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*Group{} + groupsByNames := map[string]*group.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } @@ -675,12 +717,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") - require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match") + require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match") g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") - require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match") + require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match") }) } @@ -800,7 +842,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to create an account for a user %s", userId) } - if account.Domain != domain { + if account != nil && account.Domain != domain { t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) } @@ -815,7 +857,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to get an account for a user %s", userId) } - if account.Domain != domain { + if account != nil && account.Domain != domain { t.Errorf("updating domain. expected %s got %s", domain, account.Domain) } } @@ -835,13 +877,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { } if account == nil { t.Fatalf("expected to create an account for a user %s", userId) + return } - accountId := account.Id - - _, err = manager.GetAccountByUserOrAccountID("", accountId, "") + _, err = manager.GetAccountByUserOrAccountID("", account.Id, "") if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) + t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) } _, err = manager.GetAccountByUserOrAccountID("", "", "") @@ -1124,7 +1165,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID) defer manager.peersUpdateManager.CloseChannel(peer1.ID) - group := Group{ + group := group.Group{ ID: "group-id", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -1417,7 +1458,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, - Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, + Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[string]*route.Route{ "route-1": { ID: "route-1", @@ -1518,7 +1559,7 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Groups: map[string]*Group{ + Groups: map[string]*group.Group{ "group1": { ID: "group1", Peers: []string{"peer1"}, @@ -2112,8 +2153,8 @@ func TestAccount_SetJWTGroups(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, Settings: &Settings{GroupsPropagationEnabled: true}, Users: map[string]*User{ @@ -2160,10 +2201,10 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, - "group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{}}, - "group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, + "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}}, + "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } @@ -2196,10 +2237,10 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, - "group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, - "group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, + "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, + "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, }, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } @@ -2223,7 +2264,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) } func createStore(t *testing.T) (Store, error) { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index aac35308c..18f942e68 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) @@ -193,7 +194,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) } func createDNSStore(t *testing.T) (Store, error) { @@ -278,13 +279,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return nil, err } - newGroup1 := &Group{ + newGroup1 := &group.Group{ ID: dnsGroup1ID, Peers: []string{peer1.ID}, Name: dnsGroup1ID, } - newGroup2 := &Group{ + newGroup2 := &group.Group{ ID: dnsGroup2ID, Name: dnsGroup2ID, } diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 9d70a05d1..4fffa024d 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -165,7 +165,7 @@ func (e *EphemeralManager) cleanup() { log.Debugf("delete ephemeral peer: %s", id) err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) if err != nil { - log.Tracef("failed to delete ephemeral peer: %s", err) + log.Errorf("failed to delete ephemeral peer: %s", err) } } } diff --git a/management/server/file_store.go b/management/server/file_store.go index 0228285cb..2de852bee 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -10,6 +10,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" @@ -170,7 +171,7 @@ func restore(file string) (*FileStore, error) { // Set API as issuer for groups which has not this field for _, group := range account.Groups { if group.Issued == "" { - group.Issued = GroupIssuedAPI + group.Issued = nbgroup.GroupIssuedAPI } } diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index d8575a3bf..d53298d8f 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/util" ) @@ -188,7 +189,7 @@ func TestStore(t *testing.T) { Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - account.Groups["all"] = &Group{ + account.Groups["all"] = &group.Group{ ID: "all", Name: "all", Peers: []string{"testpeer"}, @@ -320,7 +321,7 @@ func TestRestoreGroups_Migration(t *testing.T) { // create default group account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - account.Groups = map[string]*Group{ + account.Groups = map[string]*group.Group{ "cfefqs706sqkneg59g3g": { ID: "cfefqs706sqkneg59g3g", Name: "All", @@ -336,7 +337,7 @@ func TestRestoreGroups_Migration(t *testing.T) { account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") - require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") + require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") } func TestGetAccountByPrivateDomain(t *testing.T) { @@ -384,6 +385,7 @@ func TestFileStore_GetAccount(t *testing.T) { expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] if expected == nil { t.Fatalf("expected account doesn't exist") + return } account, err := store.GetAccount(expected.Id) diff --git a/management/server/group.go b/management/server/group.go index 59f05a354..0fc952cdb 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -19,51 +20,8 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } -const ( - GroupIssuedAPI = "api" - GroupIssuedJWT = "jwt" - GroupIssuedIntegration = "integration" -) - -// Group of the peers for ACL -type Group struct { - // ID of the group - ID string - - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - - // Name visible in the UI - Name string - - // Issued defines how this group was created (enum of "api", "integration" or "jwt") - Issued string - - // Peers list of the group - Peers []string `gorm:"serializer:json"` - - IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` -} - -// EventMeta returns activity event meta related to the group -func (g *Group) EventMeta() map[string]any { - return map[string]any{"name": g.Name} -} - -func (g *Group) Copy() *Group { - group := &Group{ - ID: g.ID, - Name: g.Name, - Issued: g.Issued, - Peers: make([]string, len(g.Peers)), - IntegrationReference: g.IntegrationReference, - } - copy(group.Peers, g.Peers) - return group -} - // GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*Group, error) { +func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -90,7 +48,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*G } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*Group, error) { +func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -108,7 +66,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ( return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") } - groups := make([]*Group, 0, len(account.Groups)) + groups := make([]*nbgroup.Group, 0, len(account.Groups)) for _, item := range account.Groups { groups = append(groups, item) } @@ -117,7 +75,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ( } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*Group, error) { +func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -126,7 +84,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G return nil, err } - matchingGroups := make([]*Group, 0) + matchingGroups := make([]*nbgroup.Group, 0) for _, group := range account.Groups { if group.Name == groupName { matchingGroups = append(matchingGroups, group) @@ -138,7 +96,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G } maxPeers := -1 - var groupWithMostPeers *Group + var groupWithMostPeers *nbgroup.Group for i, group := range matchingGroups { if len(group.Peers) > maxPeers { maxPeers = len(group.Peers) @@ -150,7 +108,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *Group) error { +func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -159,11 +117,11 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G return err } - if newGroup.ID == "" && newGroup.Issued != GroupIssuedAPI { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) } - if newGroup.ID == "" && newGroup.Issued == GroupIssuedAPI { + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { existingGroup, err := account.FindGroupByName(newGroup.Name) if err != nil { @@ -270,7 +228,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } // disable a deleting integration group if the initiator is not an admin service user - if g.Issued == GroupIssuedIntegration { + if g.Issued == nbgroup.GroupIssuedIntegration { executingUser := account.Users[userId] if executingUser == nil { return status.Errorf(status.NotFound, "user not found") @@ -340,6 +298,15 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } } + // check integrated peer validator groups + if account.Settings.Extra != nil { + for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups { + if groupID == integratedPeerValidatorGroups { + return &GroupLinkError{"integrated validator", g.Name} + } + } + } + delete(account.Groups, groupID) account.Network.IncSerial() @@ -355,7 +322,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } // ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { +func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -364,7 +331,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) return nil, err } - groups := make([]*Group, 0, len(account.Groups)) + groups := make([]*nbgroup.Group, 0, len(account.Groups)) for _, item := range account.Groups { groups = append(groups, item) } diff --git a/management/server/group/group.go b/management/server/group/group.go new file mode 100644 index 000000000..79dfd995c --- /dev/null +++ b/management/server/group/group.go @@ -0,0 +1,46 @@ +package group + +import "github.com/netbirdio/netbird/management/server/integration_reference" + +const ( + GroupIssuedAPI = "api" + GroupIssuedJWT = "jwt" + GroupIssuedIntegration = "integration" +) + +// Group of the peers for ACL +type Group struct { + // ID of the group + ID string + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name visible in the UI + Name string + + // Issued defines how this group was created (enum of "api", "integration" or "jwt") + Issued string + + // Peers list of the group + Peers []string `gorm:"serializer:json"` + + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// EventMeta returns activity event meta related to the group +func (g *Group) EventMeta() map[string]any { + return map[string]any{"name": g.Name} +} + +func (g *Group) Copy() *Group { + group := &Group{ + ID: g.ID, + Name: g.Name, + Issued: g.Issued, + Peers: make([]string, len(g.Peers)), + IntegrationReference: g.IntegrationReference, + } + copy(group.Peers, g.Peers) + return group +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 3a2195c88..35e9b2170 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -5,6 +5,7 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" ) @@ -24,22 +25,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to init testing account") } for _, group := range account.Groups { - group.Issued = GroupIssuedIntegration + group.Issued = nbgroup.GroupIssuedIntegration err = am.SaveGroup(account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", GroupIssuedIntegration) + t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) } } for _, group := range account.Groups { - group.Issued = GroupIssuedJWT + group.Issued = nbgroup.GroupIssuedJWT err = am.SaveGroup(account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", GroupIssuedJWT) + t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) } } for _, group := range account.Groups { - group.Issued = GroupIssuedAPI + group.Issued = nbgroup.GroupIssuedAPI group.ID = "" err = am.SaveGroup(account.Id, groupAdminUserID, group) if err == nil { @@ -129,51 +130,51 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { accountID := "testingAcc" domain := "example.com" - groupForRoute := &Group{ + groupForRoute := &nbgroup.Group{ ID: "grp-for-route", AccountID: "account-id", Name: "Group for route", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForNameServerGroups := &Group{ + groupForNameServerGroups := &nbgroup.Group{ ID: "grp-for-name-server-grp", AccountID: "account-id", Name: "Group for name server groups", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForPolicies := &Group{ + groupForPolicies := &nbgroup.Group{ ID: "grp-for-policies", AccountID: "account-id", Name: "Group for policies", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForSetupKeys := &Group{ + groupForSetupKeys := &nbgroup.Group{ ID: "grp-for-keys", AccountID: "account-id", Name: "Group for setup keys", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForUsers := &Group{ + groupForUsers := &nbgroup.Group{ ID: "grp-for-users", AccountID: "account-id", Name: "Group for users", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForIntegration := &Group{ + groupForIntegration := &nbgroup.Group{ ID: "grp-for-integration", AccountID: "account-id", Name: "Group for users integration", - Issued: GroupIssuedIntegration, + Issued: nbgroup.GroupIssuedIntegration, Peers: make([]string, 0), } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 341d202b6..340adcfc6 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -361,6 +361,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p Meta: extractPeerMeta(loginReq), UserID: userID, SetupKey: loginReq.GetSetupKey(), + ConnectionIP: realIP, }) if err != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 281089396..f8f581bd4 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -355,6 +355,7 @@ components: - user_id - version - ui_version + - approval_required AccessiblePeer: allOf: - $ref: '#/components/schemas/PeerMinimum' diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 78cd83a27..0bed93b3c 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -470,7 +470,7 @@ type Peer struct { AccessiblePeers []AccessiblePeer `json:"accessible_peers"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` @@ -539,7 +539,7 @@ type Peer struct { // PeerBase defines model for PeerBase. type PeerBase struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` @@ -611,7 +611,7 @@ type PeerBatch struct { AccessiblePeersCount int `json:"accessible_peers_count"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 56d06595f..47bcf2f32 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -4,15 +4,15 @@ import ( "encoding/json" "net/http" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/status" - - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" ) // GroupsHandler is a handler that returns groups of the account @@ -110,7 +110,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := server.Group{ + group := nbgroup.Group{ ID: groupID, Name: req.Name, Peers: peers, @@ -154,10 +154,10 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := server.Group{ + group := nbgroup.Group{ Name: req.Name, Peers: peers, - Issued: server.GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, } err = h.accountManager.SaveGroup(account.Id, user.Id, &group) @@ -240,7 +240,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { } } -func toGroupResponse(account *server.Account, group *server.Group) *api.Group { +func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { cache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 303efc9d7..3d74b848c 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -15,6 +15,7 @@ import ( "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -28,30 +29,30 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandler { +func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(accountID, userID string, group *server.Group) error { + SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_, groupID, _ string) (*server.Group, error) { + GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) { if groupID != "idofthegroup" { return nil, status.Errorf(status.NotFound, "not found") } if groupID == "id-jwt-group" { - return &server.Group{ + return &nbgroup.Group{ ID: "id-jwt-group", Name: "Default Group", - Issued: server.GroupIssuedJWT, + Issued: nbgroup.GroupIssuedJWT, }, nil } - return &server.Group{ + return &nbgroup.Group{ ID: "idofthegroup", Name: "Group", - Issued: server.GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, }, nil }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { @@ -62,10 +63,10 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle Users: map[string]*server.User{ user.Id: user, }, - Groups: map[string]*server.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI}, + Groups: map[string]*nbgroup.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, }, }, user, nil }, @@ -118,7 +119,7 @@ func TestGetGroup(t *testing.T) { }, } - group := &server.Group{ + group := &nbgroup.Group{ ID: "idofthegroup", Name: "Group", } @@ -153,7 +154,7 @@ func TestGetGroup(t *testing.T) { t.Fatalf("I don't know what I expected; %v", err) } - got := &server.Group{} + got := &nbgroup.Group{} if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index d035ae0b7..bdbeba346 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -9,7 +9,6 @@ import ( "github.com/rs/cors" "github.com/netbirdio/management-integrations/integrations" - s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/middleware" diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index d4d2558e8..77b4578f8 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -6,8 +6,10 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -61,10 +63,18 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w groupsInfo := toGroupsInfo(account.Groups, peer.ID) - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) + validPeers, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers)) + _, valid := validPeers[peer.ID] + util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) } func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { @@ -75,11 +85,18 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe return } - update := &nbpeer.Peer{ID: peerID, SSHEnabled: req.SshEnabled, Name: req.Name, - LoginExpirationEnabled: req.LoginExpirationEnabled} + update := &nbpeer.Peer{ + ID: peerID, + SSHEnabled: req.SshEnabled, + Name: req.Name, + LoginExpirationEnabled: req.LoginExpirationEnabled, + } if req.ApprovalRequired != nil { - update.Status = &nbpeer.PeerStatus{RequiresApproval: *req.ApprovalRequired} + // todo: looks like that we reset all status property, is it right? + update.Status = &nbpeer.PeerStatus{ + RequiresApproval: *req.ApprovalRequired, + } } peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update) @@ -91,15 +108,24 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) + validPeers, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers)) + _, valid := validPeers[peer.ID] + + util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) } func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { err := h.accountManager.DeletePeer(accountID, peerID, userID) if err != nil { + log.Errorf("failed to delete peer: %v", err) util.WriteError(err, w) return } @@ -138,46 +164,68 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) - if err != nil { - util.WriteError(err, w) - return - } - - peers, err := h.accountManager.GetPeers(account.Id, user.Id) - if err != nil { - util.WriteError(err, w) - return - } - - dnsDomain := h.accountManager.GetDNSDomain() - - respBody := make([]*api.PeerBatch, 0, len(peers)) - for _, peer := range peers { - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(err, w) - return - } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - - accessiblePeerNumbers := h.accessiblePeersNumber(account, peer.ID) - - respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) - } - util.WriteJSONObject(w, respBody) - return - default: + if r.Method != http.MethodGet { util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) + return } + + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + peers, err := h.accountManager.GetPeers(account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return + } + + dnsDomain := h.accountManager.GetDNSDomain() + + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { + peerToReturn, err := h.checkPeerStatus(peer) + if err != nil { + util.WriteError(err, w) + return + } + groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + + accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID) + + respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) + } + + validPeersMap, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + h.setApprovalRequiredFlag(respBody, validPeersMap) + + util.WriteJSONObject(w, respBody) } -func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) int { - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) - return len(netMap.Peers) + len(netMap.OfflinePeers) +func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) { + validatedPeersMap, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + return 0, err + } + + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) + return len(netMap.Peers) + len(netMap.OfflinePeers), nil +} + +func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { + for _, peer := range respBody { + _, ok := approvedPeersMap[peer.Id] + if !ok { + peer.ApprovalRequired = true + } + } } func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { @@ -206,7 +254,7 @@ func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.Access return accessiblePeers } -func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMinimum { +func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { var groupsInfo []api.GroupMinimum groupsChecked := make(map[string]struct{}) for _, group := range groups { @@ -230,7 +278,7 @@ func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMin return groupsInfo } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core @@ -257,7 +305,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD LastLogin: peer.LastLogin, LoginExpired: peer.Status.LoginExpired, AccessiblePeers: accessiblePeer, - ApprovalRequired: &peer.Status.RequiresApproval, + ApprovalRequired: !approved, CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, } @@ -290,7 +338,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn LastLogin: peer.LastLogin, LoginExpired: peer.Status.LoginExpired, AccessiblePeersCount: accessiblePeersCount, - ApprovalRequired: &peer.Status.RequiresApproval, CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, } diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index e6b858036..74e682854 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" @@ -51,7 +52,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { Policies: []*server.Policy{ {ID: "id-existed"}, }, - Groups: map[string]*server.Group{ + Groups: map[string]*nbgroup.Group{ "F": {ID: "F"}, "G": {ID: "G"}, }, diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index 7b68479ed..ebbd5954f 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -13,13 +13,12 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" - - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/status" ) const ( @@ -44,7 +43,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup SetupKeys: map[string]*server.SetupKey{ defaultKey.Key: defaultKey, }, - Groups: map[string]*server.Group{ + Groups: map[string]*nbgroup.Group{ "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, "id-all": {ID: "id-all", Name: "All"}, }, diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 4e2c3d0b3..2bb279c76 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -99,6 +99,8 @@ func WriteError(err error, w http.ResponseWriter) { httpStatus = http.StatusUnprocessableEntity case status.Unauthorized: httpStatus = http.StatusUnauthorized + case status.BadRequest: + httpStatus = http.StatusBadRequest default: } msg = strings.ToLower(err.Error()) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go new file mode 100644 index 000000000..cd770a801 --- /dev/null +++ b/management/server/integrated_validator.go @@ -0,0 +1,80 @@ +package server + +import ( + "errors" + + "github.com/google/martian/v3/log" + + "github.com/netbirdio/netbird/management/server/account" +) + +// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. +// It retrieves the account associated with the provided userID, then updates the integrated validator groups +// with the provided list of group ids. The updated account is then saved. +// +// Parameters: +// - accountID: The ID of the account for which integrated validator groups are to be updated. +// - userID: The ID of the user whose account is being updated. +// - groups: A slice of strings representing the ids of integrated validator groups to be updated. +// +// Returns: +// - error: An error if any occurred during the process, otherwise returns nil +func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { + ok, err := am.GroupValidation(accountID, groups) + if err != nil { + log.Debugf("error validating groups: %s", err.Error()) + return err + } + + if !ok { + log.Debugf("invalid groups") + return errors.New("invalid groups") + } + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + a, err := am.Store.GetAccountByUser(userID) + if err != nil { + return err + } + + var extra *account.ExtraSettings + + if a.Settings.Extra != nil { + extra = a.Settings.Extra + } else { + extra = &account.ExtraSettings{} + a.Settings.Extra = extra + } + extra.IntegratedValidatorGroups = groups + return am.Store.SaveAccount(a) +} + +func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { + if len(groups) == 0 { + return true, nil + } + accountsGroups, err := am.ListGroups(accountId) + if err != nil { + return false, err + } + for _, group := range groups { + var found bool + for _, accountGroup := range accountsGroups { + if accountGroup.ID == group { + found = true + break + } + } + if !found { + return false, nil + } + } + + return true, nil +} + +func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { + return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) +} diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go new file mode 100644 index 000000000..e87755b87 --- /dev/null +++ b/management/server/integrated_validator/interface.go @@ -0,0 +1,19 @@ +package integrated_validator + +import ( + "github.com/netbirdio/netbird/management/server/account" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +// IntegratedValidator interface exists to avoid the circle dependencies +type IntegratedValidator interface { + ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) + PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer + IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) + GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + PeerDeleted(accountID, peerID string) error + SetPeerInvalidationListener(fn func(accountID string)) + Stop() +} diff --git a/management/server/integration_reference/integration_reference.go b/management/server/integration_reference/integration_reference.go new file mode 100644 index 000000000..254b4e62f --- /dev/null +++ b/management/server/integration_reference/integration_reference.go @@ -0,0 +1,23 @@ +package integration_reference + +import ( + "fmt" + "strings" +) + +// IntegrationReference holds the reference to a particular integration +type IntegrationReference struct { + ID int + IntegrationType string +} + +func (ir IntegrationReference) String() string { + return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) +} + +func (ir IntegrationReference) CacheKey(path ...string) string { + if len(path) == 0 { + return ir.String() + } + return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) +} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 6ea902003..98ad0de0c 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -19,6 +17,7 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/util" ) @@ -413,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false) + eventStore, nil, false, MocIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index fb3f74cb9..13db5ae95 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -10,24 +10,22 @@ import ( sync2 "sync" "time" - "github.com/netbirdio/netbird/management/server/activity" - - "google.golang.org/grpc/credentials/insecure" - - "github.com/netbirdio/netbird/management/server" - pb "github.com/golang/protobuf/proto" //nolint - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/encryption" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/util" ) @@ -448,6 +446,43 @@ var _ = Describe("Management service", func() { }) }) +type MocIntegratedValidator struct { +} + +func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { + return update, nil +} + +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for p := range peers { + validatedPeers[p] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) { + return false, false +} + +func (MocIntegratedValidator) PeerDeleted(_, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + +} + +func (MocIntegratedValidator) Stop() {} + func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -504,7 +539,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false) + eventStore, nil, false, MocIntegratedValidator{}) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 7be3c818d..c479867d2 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,6 +5,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" @@ -32,7 +33,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { UsedTimes: 1, }, }, - Groups: map[string]*server.Group{ + Groups: map[string]*group.Group{ "1": {}, "2": {}, }, @@ -117,7 +118,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { UsedTimes: 1, }, }, - Groups: map[string]*server.Group{ + Groups: map[string]*group.Group{ "1": {}, "2": {}, }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 9463498cf..8e7c47a28 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -10,6 +10,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -31,12 +32,12 @@ type MockAccountManager struct { GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) - GetGroupFunc func(accountID, groupID, userID string) (*server.Group, error) - GetAllGroupsFunc func(accountID, userID string) ([]*server.Group, error) - GetGroupByNameFunc func(accountID, groupName string) (*server.Group, error) - SaveGroupFunc func(accountID, userID string, group *server.Group) error + GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) + GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) + GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) + SaveGroupFunc func(accountID, userID string, group *group.Group) error DeleteGroupFunc func(accountID, userId, groupID string) error - ListGroupsFunc func(accountID string) ([]*server.Group, error) + ListGroupsFunc func(accountID string) ([]*group.Group, error) GroupAddPeerFunc func(accountID, groupID, peerID string) error GroupDeletePeerFunc func(accountID, groupID, peerID string) error DeleteRuleFunc func(accountID, ruleID, userID string) error @@ -91,10 +92,20 @@ type MockAccountManager struct { DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager + UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error + GroupValidationFunc func(accountId string, groups []string) (bool, error) +} + +func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { + approvedPeers := make(map[string]struct{}) + for id := range account.Peers { + approvedPeers[id] = struct{}{} + } + return approvedPeers, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*server.Group, error) { +func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*group.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupFunc(accountId, groupID, userID) } @@ -102,7 +113,7 @@ func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*serv } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*server.Group, error) { +func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*group.Group, error) { if am.GetAllGroupsFunc != nil { return am.GetAllGroupsFunc(accountID, userID) } @@ -261,7 +272,7 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*server.Group, error) { +func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*group.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupByNameFunc(accountID, groupName) } @@ -269,7 +280,7 @@ func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*serv } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server.Group) error { +func (am *MockAccountManager) SaveGroup(accountID, userID string, group *group.Group) error { if am.SaveGroupFunc != nil { return am.SaveGroupFunc(accountID, userID, group) } @@ -285,7 +296,7 @@ func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) err } // ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, error) { +func (am *MockAccountManager) ListGroups(accountID string) ([]*group.Group, error) { if am.ListGroupsFunc != nil { return am.ListGroupsFunc(accountID) } @@ -694,3 +705,19 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager { } return nil } + +// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface +func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { + if am.UpdateIntegratedValidatorGroupsFunc != nil { + return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups) + } + return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented") +} + +// GroupValidation mocks GroupValidation of the AccountManager interface +func (am *MockAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { + if am.GroupValidationFunc != nil { + return am.GroupValidationFunc(accountId, groups) + } + return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index e521805c8..fa7793602 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -10,6 +10,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -261,7 +262,7 @@ func validateNSList(list []nbdns.NameServer) error { return nil } -func validateGroups(list []string, groups map[string]*Group) error { +func validateGroups(list []string, groups map[string]*nbgroup.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index d04ac1a20..b10f9387a 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -8,6 +8,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -759,7 +760,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createNSStore(t *testing.T) (Store, error) { @@ -831,12 +832,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - newGroup1 := &Group{ + newGroup1 := &nbgroup.Group{ ID: group1ID, Name: group1ID, } - newGroup2 := &Group{ + newGroup2 := &nbgroup.Group{ ID: group2ID, Name: group2ID, } diff --git a/management/server/peer.go b/management/server/peer.go index 7de1b6542..fda8e49e9 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -7,16 +7,12 @@ import ( "time" "github.com/rs/xid" - - "github.com/netbirdio/management-integrations/additions" - - "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/status" ) // PeerSync used as a data object between the gRPC API and AccountManager on Sync request. @@ -37,6 +33,8 @@ type PeerLogin struct { UserID string // SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required. SetupKey string + // ConnectionIP is the real IP of the peer + ConnectionIP net.IP } // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if @@ -52,6 +50,10 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P return nil, err } + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, err + } peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) @@ -71,7 +73,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(peer.ID) + aclPeers, _ := account.getPeerConnectionResources(peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -167,7 +169,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) } - update, err = additions.ValidatePeersUpdateRequest(update, peer, userID, accountID, am.eventStore, am.GetDNSDomain()) + update, err = am.integratedPeerValidator.ValidatePeer(update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, err } @@ -244,6 +246,12 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, // the 2nd loop performs the actual modification for _, peer := range peers { + + err := am.integratedPeerValidator.PeerDeleted(account.Id, peer.ID) + if err != nil { + return err + } + account.DeletePeer(peer.ID) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{ @@ -304,7 +312,17 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro if peer == nil { return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) } - return account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + groups := make(map[string][]string) + for groupID, group := range account.Groups { + groups[groupID] = group.Peers + } + + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + if err != nil { + return nil, err + } + return account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validatedPeers), nil } // GetPeerNetwork returns the Network for a given peer @@ -433,10 +451,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P CreatedAt: registrationTime, LoginExpirationEnabled: addedByUser, Ephemeral: ephemeral, - } - - if account.Settings.Extra != nil { - newPeer = additions.PreparePeer(newPeer, account.Settings.Extra) + Location: peer.Location, } // add peer to 'All' group @@ -467,6 +482,8 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } } + newPeer = am.integratedPeerValidator.PreparePeer(account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) + if addedByUser { user, err := account.FindUser(userID) if err != nil { @@ -492,7 +509,11 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P am.updateAccountPeers(account) - networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain) + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap) return newPeer, networkMap, nil } @@ -529,23 +550,53 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network if peerLoginExpired(peer, account) { return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + if requiresApproval { + emptyMap := &NetworkMap{ + Network: account.Network.Copy(), + } + return peer, emptyMap, nil + } + + if isStatusChanged { + am.updateAccountPeers(account) + } + + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil } // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) { account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey) - if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. - return am.AddPeer(login.SetupKey, login.UserID, &nbpeer.Peer{ + newPeer := &nbpeer.Peer{ Key: login.WireGuardPubKey, Meta: login.Meta, SSHKey: login.SSHKey, - }) + } + if am.geo != nil && login.ConnectionIP != nil { + location, err := am.geo.Lookup(login.ConnectionIP) + if err != nil { + log.Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err) + } else { + newPeer.Location.ConnectionIP = login.ConnectionIP + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + + } + } + + return am.AddPeer(login.SetupKey, login.UserID, newPeer) } log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) return nil, nil, status.Errorf(status.Internal, "failed while logging in peer") @@ -595,6 +646,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) } + isRequiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peer, updated := updatePeerMeta(peer, login.Meta, account) if updated { shouldStoreAccount = true @@ -612,10 +664,23 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } } - if updateRemotePeers { + if updateRemotePeers || isStatusChanged { am.updateAccountPeers(account) } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + if isRequiresApproval { + emptyMap := &NetworkMap{ + Network: account.Network.Copy(), + } + return peer, emptyMap, nil + } + + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil } func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { @@ -764,8 +829,13 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp return nil, err } + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, err + } + for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(p.ID) + aclPeers, _ := account.getPeerConnectionResources(p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -789,8 +859,13 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco func (am *DefaultAccountManager) updateAccountPeers(account *Account) { peers := account.GetPeers() + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed send out updates to peers, failed to validate peer: %v", err) + return + } for _, peer := range peers { - remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain) + remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 7f6d440bb..6063cc2a7 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -199,8 +200,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } var ( - group1 Group - group2 Group + group1 nbgroup.Group + group2 nbgroup.Group policy Policy ) diff --git a/management/server/policy.go b/management/server/policy.go index 8265dabb5..e162d2b3b 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -5,11 +5,11 @@ import ( "strconv" "strings" - "github.com/netbirdio/management-integrations/additions" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -211,7 +211,8 @@ type FirewallRule struct { // getPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { + generateResources, getAccumulatedResources := a.connResourcesGenerator() for _, policy := range a.Policies { if !policy.Enabled { @@ -223,10 +224,8 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []* continue } - sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks) - destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil) - sourcePeers = additions.ValidatePeers(sourcePeers) - destinationPeers = additions.ValidatePeers(destinationPeers) + sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap) if rule.Bidirectional { if peerInSources { @@ -264,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in all, err := a.GetGroupAll() if err != nil { log.Errorf("failed to get group all: %v", err) - all = &Group{} + all = &nbgroup.Group{} } return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { @@ -491,7 +490,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { // // Important: Posture checks are applicable only to source group peers, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { +func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) for _, g := range groups { @@ -512,6 +511,10 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou continue } + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + if peer.ID == peerID { peerInGroups = true continue diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 681bab1da..1ea3bb379 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" ) @@ -56,7 +57,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -135,16 +136,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, } + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(p.ID) + peers, firewallRules := account.getPeerConnectionResources(p.ID, validatedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", validatedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -299,7 +305,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -374,8 +380,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }, } + approvedPeers := make(map[string]struct{}) + for p := range account.Peers { + approvedPeers[p] = struct{}{} + } + t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -403,7 +414,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC") + peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -433,7 +444,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -454,7 +465,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC") + peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -569,7 +580,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -644,10 +655,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }) + approvedPeers := make(map[string]struct{}) + for p := range account.Peers { + approvedPeers[p] = struct{}{} + } t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -657,7 +672,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC") + peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) expectedFirewallRules := []*FirewallRule{ @@ -673,7 +688,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE") + peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -683,7 +698,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerI") + peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -698,19 +713,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources("peerI") + peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC") + peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -725,14 +740,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE") + peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources("peerA") + peers, firewallRules = account.getPeerConnectionResources("peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5a56eaa8b..9f8ea08c9 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" ) @@ -858,7 +859,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { groups, err := am.ListGroups(account.Id) require.NoError(t, err) - var groupHA1, groupHA2 *Group + var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { switch group.Name { case routeGroupHA1: @@ -967,7 +968,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") - newGroup := &Group{ + newGroup := &nbgroup.Group{ ID: xid.New().String(), Name: "peer1 group", Peers: []string{peer1ID}, @@ -1014,7 +1015,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createRouterStore(t *testing.T) (Store, error) { @@ -1195,7 +1196,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - newGroup := []*Group{ + newGroup := []*nbgroup.Group{ { ID: routeGroup1, Name: routeGroup1, diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index b714652f1..43edabbd6 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { @@ -24,7 +25,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -82,7 +83,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -91,7 +92,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -178,7 +179,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -187,7 +188,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index f6a6f92a7..e6a9c8467 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -17,6 +17,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -64,7 +65,7 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, sql.SetMaxOpenConns(conns) // TODO: make it configurable err = db.AutoMigrate( - &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, + &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, ) @@ -99,17 +100,17 @@ func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics t // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { - log.Debugf("acquiring global lock") + log.Tracef("acquiring global lock") start := time.Now() s.globalAccountLock.Lock() unlock = func() { s.globalAccountLock.Unlock() - log.Debugf("released global lock in %v", time.Since(start)) + log.Tracef("released global lock in %v", time.Since(start)) } took := time.Since(start) - log.Debugf("took %v to acquire global lock", took) + log.Tracef("took %v to acquire global lock", took) if s.metrics != nil { s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) } @@ -118,7 +119,7 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { } func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { - log.Debugf("acquiring lock for account %s", accountID) + log.Tracef("acquiring lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) @@ -127,7 +128,7 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { unlock = func() { mtx.Unlock() - log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + log.Tracef("released lock for account %s in %v", accountID, time.Since(start)) } return unlock @@ -434,7 +435,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { } account.UsersG = nil - account.Groups = make(map[string]*Group, len(account.GroupsG)) + account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG)) for _, group := range account.GroupsG { account.Groups[group.ID] = group.Copy() } diff --git a/management/server/user.go b/management/server/user.go index 15517db41..b955c4058 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -10,6 +10,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -49,23 +50,6 @@ type UserStatus string // UserRole is the role of a User type UserRole string -// IntegrationReference holds the reference to a particular integration -type IntegrationReference struct { - ID int - IntegrationType string -} - -func (ir IntegrationReference) String() string { - return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) -} - -func (ir IntegrationReference) CacheKey(path ...string) string { - if len(path) == 0 { - return ir.String() - } - return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) -} - // User represents a user of the system type User struct { Id string `gorm:"primaryKey"` @@ -91,7 +75,7 @@ type User struct { // Issued of the user Issued string `gorm:"default:api"` - IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } // IsBlocked returns true if the user is blocked, false otherwise diff --git a/management/server/user_test.go b/management/server/user_test.go index e34aa406d..c92f87e6c 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" ) @@ -276,7 +277,7 @@ func TestUser_Copy(t *testing.T) { LastLogin: time.Now().UTC(), CreatedAt: time.Now().UTC(), Issued: "test", - IntegrationReference: IntegrationReference{ + IntegrationReference: integration_reference.IntegrationReference{ ID: 0, IntegrationType: "test", }, @@ -603,8 +604,9 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + integratedPeerValidator: MocIntegratedValidator{}, } testCases := []struct { @@ -793,7 +795,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { Id: "externalUser", Role: UserRoleUser, Issued: UserIssuedIntegration, - IntegrationReference: IntegrationReference{ + IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", },