Compare commits

...

4 Commits

Author SHA1 Message Date
Carlos Hernandez
f603cd9202 [client] Check wginterface instead of engine ctx (#2676)
Moving code to ensure wgInterface is gone right after context is
cancelled/stop in the off chance that on next retry the backoff
operation is permanently cancelled and interface is abandoned without
destroying.
2024-10-04 19:15:16 +02:00
Bethuel Mmbaga
5897a48e29 fix wrong reference (#2695)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 18:55:25 +03:00
Bethuel Mmbaga
8bf729c7b4 [management] Add AccountExists to AccountManager (#2694)
* Add AccountExists method to account manager interface

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove unused code

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 18:09:40 +03:00
Bethuel Mmbaga
7f09b39769 [management] Refactor User JWT group sync (#2690)
* Refactor GetAccountIDByUserOrAccountID

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* sync user jwt group changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* propagate jwt group changes to peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix no jwt groups synced

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests and lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Move the account peer update outside the transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move updateUserPeersInGroups to account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move event store outside of transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get user with update lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run jwt sync in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 17:17:01 +03:00
10 changed files with 481 additions and 197 deletions

View File

@@ -269,12 +269,6 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks() checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
if c.engine != nil && c.engine.ctx.Err() != nil {
log.Info("Stopping Netbird Engine")
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock() c.engineMutex.Unlock()
@@ -294,6 +288,15 @@ func (c *ConnectClient) run(
} }
<-engineCtx.Done() <-engineCtx.Done()
c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown() c.statusRecorder.ClientTeardown()
backOff.Reset() backOff.Reset()

View File

@@ -251,6 +251,13 @@ func (e *Engine) Stop() error {
} }
log.Info("Network monitor: stopped") log.Info("Network monitor: stopped")
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
}
err := e.removeAllPeers() err := e.removeAllPeers()
if err != nil { if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
@@ -1116,18 +1123,12 @@ func (e *Engine) close() {
} }
} }
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil { if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil { if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
} }
e.wgInterface = nil
} }
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
@@ -1395,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() {
} }
// Set a new timer to debounce rapid network changes // Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() { debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period // This function is called after the debounce period
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
@@ -1426,6 +1427,11 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
} }
func (e *Engine) stopDNSServer() { func (e *Engine) stopDNSServer() {
if e.dnsServer == nil {
return
}
e.dnsServer.Stop()
e.dnsServer = nil
err := fmt.Errorf("DNS server stopped") err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates() nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates { for i := range nsGroupStates {
@@ -1433,10 +1439,6 @@ func (e *Engine) stopDNSServer() {
nsGroupStates[i].Error = err nsGroupStates[i].Error = err
} }
e.statusRecorder.UpdateDNSStates(nsGroupStates) e.statusRecorder.UpdateDNSStates(nsGroupStates)
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
} }
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.

View File

@@ -76,7 +76,8 @@ type AccountManager interface {
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) AccountExists(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
@@ -843,55 +844,54 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
return a.Peers[peerID] return a.Peers[peerID]
} }
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns true if there are changes in the JWT group membership. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { // newly groups to create and an error if any occurred.
user, ok := a.Users[userID] func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
if !ok {
return false
}
existedGroupsByName := make(map[string]*nbgroup.Group) existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range a.Groups { for _, group := range groups {
existedGroupsByName[group.Name] = group existedGroupsByName[group.Name] = group
} }
newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups)
groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames)
// If no groups are added or removed, we should not sync account // If no groups are added or removed, we should not sync account
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return false return false, nil, nil, nil
} }
newGroupsToCreate := make([]*nbgroup.Group, 0)
var modified bool var modified bool
for _, name := range groupsToAdd { for _, name := range groupsToAdd {
group, exists := existedGroupsByName[name] group, exists := existedGroupsByName[name]
if !exists { if !exists {
group = &nbgroup.Group{ group = &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
AccountID: user.AccountID,
Name: name, Name: name,
Issued: nbgroup.GroupIssuedJWT, Issued: nbgroup.GroupIssuedJWT,
} }
a.Groups[group.ID] = group newGroupsToCreate = append(newGroupsToCreate, group)
} }
if group.Issued == nbgroup.GroupIssuedJWT { if group.Issued == nbgroup.GroupIssuedJWT {
newAutoGroups = append(newAutoGroups, group.ID) newUserAutoGroups = append(newUserAutoGroups, group.ID)
modified = true modified = true
} }
} }
for name, id := range jwtGroupsMap { for name, id := range jwtGroupsMap {
if !slices.Contains(groupsToRemove, name) { if !slices.Contains(groupsToRemove, name) {
newAutoGroups = append(newAutoGroups, id) newUserAutoGroups = append(newUserAutoGroups, id)
continue continue
} }
modified = true modified = true
} }
user.AutoGroups = newAutoGroups
return modified return modified, newUserAutoGroups, newGroupsToCreate, nil
} }
// UserGroupsAddToPeers adds groups to all peers of user // UserGroupsAddToPeers adds groups to all peers of user
@@ -1262,24 +1262,23 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil return nil
} }
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. // AccountExists checks if an account exists.
// If an accountID is provided, it checks if the account exists and returns it. func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. return am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
}
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
// If user does have an account, it returns the user's account ID.
// If the user doesn't have an account, it creates one using the provided domain. // If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created. // Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
if accountID != "" { if userID == "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) return "", status.Errorf(status.NotFound, "no valid userID provided")
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
return accountID, nil
} }
if userID != "" { accountID, err := am.Store.GetAccountIDByUserID(userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil { if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
@@ -1288,11 +1287,11 @@ func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Conte
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
return "", err return "", err
} }
return account.Id, nil return account.Id, nil
} }
return "", err
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") }
return accountID, nil
} }
func isNil(i idp.Manager) bool { func isNil(i idp.Manager) bool {
@@ -1796,6 +1795,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if user.AccountID != accountID {
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
}
if !user.IsServiceUser && claims.Invited { if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, accountID, user.Id) err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil { if err != nil {
@@ -1803,7 +1806,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
} }
} }
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
return "", "", err return "", "", err
} }
@@ -1812,7 +1815,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled. // and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
@@ -1823,67 +1826,134 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
if settings.JWTGroupsClaimName == "" { if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set")
return nil return nil
} }
// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
oldGroups := make([]string, len(user.AutoGroups)) unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
copy(oldGroups, user.AutoGroups) defer func() {
if unlockPeer != nil {
unlockPeer()
}
}()
// Update the account if group membership changes var addNewGroups []string
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { var removeOldGroups []string
addNewGroups := difference(user.AutoGroups, oldGroups) var hasChanges bool
removeOldGroups := difference(oldGroups, user.AutoGroups) var user *User
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if settings.GroupsPropagationEnabled { user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) if err != nil {
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) return fmt.Errorf("error getting user: %w", err)
account.Network.IncSerial()
} }
if err := am.Store.SaveAccount(ctx, account); err != nil { groups, err := am.Store.GetAccountGroups(ctx, accountID)
log.WithContext(ctx).Errorf("failed to save account: %v", err) if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
if err != nil {
return fmt.Errorf("error getting JWT groups changes: %w", err)
}
hasChanges = changed
// skip update if no changes
if !changed {
return nil return nil
} }
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) groups, err = transaction.GetAccountGroups(ctx, accountID)
am.updateAccountPeers(ctx, account) if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
unlockPeer()
unlockPeer = nil
return nil
})
if err != nil {
return err
}
if !hasChanges {
return nil
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil { group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, if err != nil {
map[string]any{ log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
"group": group.Name, } else {
"group_id": group.ID, meta := map[string]any{
"is_service_user": user.IsServiceUser, "group": group.Name, "group_id": group.ID,
"user_name": user.ServiceUserName}) "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
} }
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil { group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, if err != nil {
map[string]any{ log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
"group": group.Name, } else {
"group_id": group.ID, meta := map[string]any{
"is_service_user": user.IsServiceUser, "group": group.Name, "group_id": group.ID,
"user_name": user.ServiceUserName}) "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
} }
} }
if settings.GroupsPropagationEnabled {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
} }
return nil return nil
@@ -1916,7 +1986,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) if claims.AccountId != "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
}
return claims.AccountId, nil
}
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil { if err != nil {
@@ -2229,7 +2309,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
routes := make(map[route.ID]*route.Route) routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*SetupKey{} setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup) nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID)
owner := NewOwnerUser(userID)
owner.AccountID = accountID
users[userID] = owner
dnsSettings := DNSSettings{ dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0), DisabledManagementGroups: make([]string, 0),
} }
@@ -2297,12 +2381,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
// separateGroups separates user's auto groups into non-JWT and JWT groups. // separateGroups separates user's auto groups into non-JWT and JWT groups.
// Returns the list of standard auto groups and a map of JWT auto groups, // Returns the list of standard auto groups and a map of JWT auto groups,
// where the keys are the group names and the values are the group IDs. // where the keys are the group names and the values are the group IDs.
func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) {
newAutoGroups := make([]string, 0) newAutoGroups := make([]string, 0)
jwtAutoGroups := make(map[string]string) // map of group name to group ID jwtAutoGroups := make(map[string]string) // map of group name to group ID
allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups))
for _, group := range allGroups {
allGroupsMap[group.ID] = group
}
for _, id := range autoGroups { for _, id := range autoGroups {
if group, ok := allGroups[id]; ok { if group, ok := allGroupsMap[id]; ok {
if group.Issued == nbgroup.GroupIssuedJWT { if group.Issued == nbgroup.GroupIssuedJWT {
jwtAutoGroups[group.Name] = id jwtAutoGroups[group.Name] = id
} else { } else {
@@ -2310,5 +2399,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([
} }
} }
} }
return newAutoGroups, jwtAutoGroups return newAutoGroups, jwtAutoGroups
} }

View File

@@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID) initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id" userId := "user-id"
domain := "test.domain" domain := "test.domain"
initAccount := newAccountWithId(context.Background(), "", userId, domain) _ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization // as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated // that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err = manager.Store.GetAccount(context.Background(), accountID) initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed") require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
@@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
} }
} }
func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { func TestAccountManager_GetAccountByUserID(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return return
} }
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
if err != nil { assert.NoError(t, err)
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) assert.True(t, exists, "expected to get existing account after creation using userid")
}
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") _, err = manager.GetAccountIDByUserID(context.Background(), "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user ID is empty")
} }
} }
@@ -1669,7 +1667,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
@@ -1684,7 +1682,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@@ -1696,7 +1694,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1742,7 +1740,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@@ -1770,7 +1768,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}, },
} }
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1790,7 +1788,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@@ -1802,7 +1800,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1850,7 +1848,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
@@ -1861,9 +1859,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
@@ -2199,8 +2194,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
} }
func TestAccount_SetJWTGroups(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
// create a new account // create a new account
account := &Account{ account := &Account{
Id: "accountID",
Peers: map[string]*nbpeer.Peer{ Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
@@ -2211,62 +2210,120 @@ func TestAccount_SetJWTGroups(t *testing.T) {
Groups: map[string]*group.Group{ Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
}, },
Settings: &Settings{GroupsPropagationEnabled: true}, Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
Users: map[string]*User{ Users: map[string]*User{
"user1": {Id: "user1"}, "user1": {Id: "user1", AccountID: "accountID"},
"user2": {Id: "user2"}, "user2": {Id: "user2", AccountID: "accountID"},
}, },
} }
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("empty jwt groups", func(t *testing.T) { t.Run("empty jwt groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") Raw: jwt.MapClaims{"groups": []interface{}{}},
}
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
}) })
t.Run("jwt match existing api group", func(t *testing.T) { t.Run("jwt match existing api group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1"}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") }
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"} account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"]))
updated := account.SetJWTGroups("user1", []string{"group1"}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") }
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
t.Run("add jwt group", func(t *testing.T) { t.Run("add jwt group", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) claims := jwtclaims.AuthorizationClaims{
assert.True(t, updated, "account should be updated") UserId: "user1",
assert.Len(t, account.Groups, 2, "new group should be added") Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") }
assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
t.Run("existed group not update", func(t *testing.T) { t.Run("existed group not update", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{"group2"}) claims := jwtclaims.AuthorizationClaims{
assert.False(t, updated, "account should not be updated") UserId: "user1",
assert.Len(t, account.Groups, 2, "groups count should not be changed") Raw: jwt.MapClaims{"groups": []interface{}{"group2"}},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
t.Run("add new group", func(t *testing.T) { t.Run("add new group", func(t *testing.T) {
updated := account.SetJWTGroups("user2", []string{"group1", "group3"}) claims := jwtclaims.AuthorizationClaims{
assert.True(t, updated, "account should be updated") UserId: "user2",
assert.Len(t, account.Groups, 3, "new group should be added") Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") }
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
t.Run("remove all JWT groups", func(t *testing.T) { t.Run("remove all JWT groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{}) claims := jwtclaims.AuthorizationClaims{
assert.True(t, updated, "account should be updated") UserId: "user1",
assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain") Raw: jwt.MapClaims{"groups": []interface{}{}},
assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present") }
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
}) })
} }

View File

@@ -27,7 +27,8 @@ type MockAccountManager struct {
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -58,7 +59,7 @@ type MockAccountManager struct {
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
@@ -194,14 +195,22 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface // AccountExists mock implementation of AccountExists from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
if am.GetAccountIDByUserOrAccountIdFunc != nil { if am.AccountExistsFunc != nil {
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) return am.AccountExistsFunc(ctx, accountID)
}
return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
}
// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
if am.GetAccountIDByUserIdFunc != nil {
return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
} }
return "", status.Errorf( return "", status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountIDByUserOrAccountID is not implemented", "method GetAccountIDByUserID is not implemented",
) )
} }
@@ -444,7 +453,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface // CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil { if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute) return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute)
} }
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
} }

View File

@@ -378,15 +378,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Create(&usersToSave).Error Create(&usersToSave).Error
} }
// SaveGroups saves the given list of groups to the database. // SaveUser saves the given user to the database.
// It updates existing groups if a conflict occurs. func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
groupsToSave := make([]nbgroup.Group, 0, len(groups)) if result.Error != nil {
for _, group := range groups { return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
} }
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error return nil
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
if len(groups) == 0 {
return nil
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
} }
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore // DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -1021,6 +1032,11 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
return nil return nil
} }
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account") return status.Errorf(status.Internal, "issue adding peer to account")
@@ -1127,6 +1143,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil return &group, nil
} }
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)

View File

@@ -1185,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes) assert.Equal(t, 2, setupKey.UsedTimes)
} }
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
group := &nbgroup.Group{
ID: "group-id",
AccountID: "account-id",
Name: "group-name",
Issued: "api",
Peers: nil,
}
err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
if err != nil {
t.Fatal("failed to get group")
return err
}
t.Logf("group: %v", group)
return nil
})
assert.NoError(t, err)
}

View File

@@ -60,6 +60,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@@ -68,7 +69,8 @@ type Store interface {
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@@ -82,6 +84,7 @@ type Store interface {
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error

View File

@@ -8,14 +8,14 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
) )
const ( const (
@@ -1254,6 +1254,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
} }
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}
userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{}
}
for _, gid := range groupsToAdd {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
addUserPeersToGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
for _, gid := range groupsToRemove {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
removeUserPeersFromGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
return groupsToUpdate, nil
}
// addUserPeersToGroup adds the user's peers to the group.
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
groupPeers := make(map[string]struct{}, len(group.Peers))
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
for pid := range userPeerIDs {
groupPeers[pid] = struct{}{}
}
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
}
// removeUserPeersFromGroup removes user's peers from the group.
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
// skip removing peers from group All
if group.Name == "All" {
return
}
updatedPeers := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
if _, found := userPeerIDs[pid]; !found {
updatedPeers = append(updatedPeers, pid)
}
}
group.Peers = updatedPeers
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData { for _, user := range userData {
if user.ID == userID { if user.ID == userID {

View File

@@ -813,10 +813,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") acc, err := am.Store.GetAccount(context.Background(), account.Id)
assert.NoError(t, err)
acc, err := am.Store.GetAccount(context.Background(), accID)
assert.NoError(t, err) assert.NoError(t, err)
for _, id := range tc.expectedDeleted { for _, id := range tc.expectedDeleted {