diff --git a/management/server/group.go b/management/server/group.go index 758b28b76..ca9b042b6 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -598,15 +598,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if group, exists := account.Groups[groupID]; exists && group.HasPeers() { - return true - } - } - return false -} - // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) diff --git a/management/server/route.go b/management/server/route.go index ecb562645..0c02d991e 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -9,6 +9,7 @@ import ( "strings" "unicode/utf8" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -52,17 +53,46 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID)) +} + +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func getRoutesByPrefixOrDomains(ctx context.Context, transaction Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { + accountRoutes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + routes := make([]*route.Route, 0) + for _, r := range accountRoutes { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { + routes = append(routes, r) + } + } + + return routes, nil } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. -func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { +func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction Store, accountID string, checkRoute *route.Route, groupsMap map[string]*nbgroup.Group) error { // routes can have both peer and peer_groups - routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) + prefix := checkRoute.Network + domains := checkRoute.Domains + + routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains) + if err != nil { + return err + } // lets remember all the peers and the peer groups from routesWithPrefix seenPeers := make(map[string]bool) @@ -71,18 +101,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account for _, prefixRoute := range routesWithPrefix { // we skip route(s) with the same network ID as we want to allow updating of the existing route // when creating a new route routeID is newly generated so nothing will be skipped - if routeID == prefixRoute.ID { + if checkRoute.ID == prefixRoute.ID { continue } if prefixRoute.Peer != "" { seenPeers[string(prefixRoute.ID)] = true } + + peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, prefixRoute.PeerGroups) + if err != nil { + return err + } + for _, groupID := range prefixRoute.PeerGroups { seenPeerGroups[groupID] = true - group := account.GetGroup(groupID) - if group == nil { + group, ok := peerGroupsMap[groupID] + if !ok || group == nil { return status.Errorf( status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", getRouteDescriptor(prefix, domains), groupID, @@ -95,12 +131,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } } - if peerID != "" { + if peerID := checkRoute.Peer; peerID != "" { // check that peerID exists and is not in any route as single peer or part of the group - peer := account.GetPeer(peerID) - if peer == nil { + _, err = transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) + if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } + if _, ok := seenPeers[peerID]; ok { return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) @@ -108,9 +145,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } // check that peerGroupIDs are not in any route peerGroups list - for _, groupID := range peerGroupIDs { - group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. - + for _, groupID := range checkRoute.PeerGroups { + group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again. if _, ok := seenPeerGroups[groupID]; ok { return status.Errorf( status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", @@ -118,12 +154,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix + peersMap, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, group.Peers) + if err != nil { + return err + } + for _, id := range group.Peers { if _, ok := seenPeers[id]; ok { - peer := account.GetPeer(id) - if peer == nil { - return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) + peer, ok := peersMap[id] + if !ok || peer == nil { + return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id) } + return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s from the group %s already has this route", getRouteDescriptor(prefix, domains), peer.Name, group.Name) @@ -146,104 +188,63 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - // Do not allow non-Linux peers - if peer := account.GetPeer(peerID); peer != nil { - if peer.Meta.GoOS != "linux" { - return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + var newRoute *route.Route + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + newRoute = &route.Route{ + ID: route.ID(xid.New().String()), + AccountID: accountID, + Network: prefix, + Domains: domains, + KeepRoute: keepRoute, + NetID: netID, + Description: description, + Peer: peerID, + PeerGroups: peerGroupIDs, + NetworkType: networkType, + Masquerade: masquerade, + Metric: metric, + Enabled: enabled, + Groups: groups, + AccessControlGroups: accessControlGroupIDs, } - } - if len(domains) > 0 && prefix.IsValid() { - return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") - } + if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil { + return err + } - if len(domains) == 0 && !prefix.IsValid() { - return nil, status.Errorf(status.InvalidArgument, "invalid Prefix") - } - - if len(domains) > 0 { - prefix = getPlaceholderIP() - } - - if peerID != "" && len(peerGroupIDs) != 0 { - return nil, status.Errorf( - status.InvalidArgument, - "peer with ID %s and peers group %s should not be provided at the same time", - peerID, peerGroupIDs) - } - - var newRoute route.Route - newRoute.ID = route.ID(xid.New().String()) - - if len(peerGroupIDs) > 0 { - err = validateGroups(peerGroupIDs, account.Groups) + updateAccountPeers, err = areRouteChangesAffectPeers(ctx, am.Store, newRoute) if err != nil { - return nil, err + return err } - } - if len(accessControlGroupIDs) > 0 { - err = validateGroups(accessControlGroupIDs, account.Groups) - if err != nil { - return nil, err + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err } - } - err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) + return transaction.SaveRoute(ctx, LockingStrengthUpdate, newRoute) + }) if err != nil { return nil, err } - if metric < route.MinMetric || metric > route.MaxMetric { - return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) - } - - if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" { - return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) - } - - err = validateGroups(groups, account.Groups) - if err != nil { - return nil, err - } - - newRoute.Peer = peerID - newRoute.PeerGroups = peerGroupIDs - newRoute.Network = prefix - newRoute.Domains = domains - newRoute.NetworkType = networkType - newRoute.Description = description - newRoute.NetID = netID - newRoute.Masquerade = masquerade - newRoute.Metric = metric - newRoute.Enabled = enabled - newRoute.Groups = groups - newRoute.KeepRoute = keepRoute - newRoute.AccessControlGroups = accessControlGroupIDs - - if account.Routes == nil { - account.Routes = make(map[route.ID]*route.Route) - } - - account.Routes[newRoute.ID] = &newRoute - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err - } - - if am.isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, accountID) - } - am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) - return &newRoute, nil + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + + return newRoute, nil } // SaveRoute saves route @@ -251,10 +252,147 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - if routeToSave == nil { - return status.Errorf(status.InvalidArgument, "route provided is nil") + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err } + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var oldRoute *route.Route + var oldRouteAffectsPeers bool + var newRouteAffectsPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil { + return err + } + + oldRoute, err = transaction.GetRouteByID(ctx, LockingStrengthUpdate, accountID, string(routeToSave.ID)) + if err != nil { + return err + } + + oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute) + if err != nil { + return err + } + + newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave) + if err != nil { + return err + } + routeToSave.AccountID = accountID + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveRoute(ctx, LockingStrengthUpdate, routeToSave) + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) + + if oldRouteAffectsPeers || newRouteAffectsPeers { + am.updateAccountPeers(ctx, accountID) + } + + return nil +} + +// DeleteRoute deletes route with routeID +func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var route *route.Route + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + route, err = transaction.GetRouteByID(ctx, LockingStrengthUpdate, accountID, string(routeID)) + if err != nil { + return err + } + + updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeleteRoute(ctx, LockingStrengthUpdate, accountID, string(routeID)) + }) + + am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + + return nil +} + +// ListRoutes returns a list of routes from account +func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) +} + +func validateRoute(ctx context.Context, transaction Store, accountID string, routeToSave *route.Route) error { + if err := validateRouteProperties(routeToSave); err != nil { + return err + } + + if routeToSave.Peer != "" { + peer, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, routeToSave.Peer) + if err != nil { + return err + } + + if peer.Meta.GoOS != "linux" { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + + groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave) + if err != nil { + return err + } + + return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap) +} + +// Helper to validate route properties. +func validateRouteProperties(routeToSave *route.Route) error { if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric { return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) } @@ -263,18 +401,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - // Do not allow non-Linux peers - if peer := account.GetPeer(routeToSave.Peer); peer != nil { - if peer.Meta.GoOS != "linux" { - return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") - } - } - if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -291,89 +417,34 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") } - if len(routeToSave.PeerGroups) > 0 { - err = validateGroups(routeToSave.PeerGroups, account.Groups) - if err != nil { - return err - } - } - - if len(routeToSave.AccessControlGroups) > 0 { - err = validateGroups(routeToSave.AccessControlGroups, account.Groups) - if err != nil { - return err - } - } - - err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) - if err != nil { - return err - } - - err = validateGroups(routeToSave.Groups, account.Groups) - if err != nil { - return err - } - - oldRoute := account.Routes[routeToSave.ID] - account.Routes[routeToSave.ID] = routeToSave - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, accountID) - } - - am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) - return nil } -// DeleteRoute deletes route with routeID -func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - routy := account.Routes[routeID] - if routy == nil { - return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID) - } - delete(account.Routes, routeID) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - - if am.isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, accountID) - } - - return nil -} - -// ListRoutes returns a list of routes from account -func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +// validateRouteGroups validates the route groups and returns the validated groups map. +func validateRouteGroups(ctx context.Context, transaction Store, accountID string, routeToSave *route.Route) (map[string]*nbgroup.Group, error) { + groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) + groupsMap, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupsToValidate) if err != nil { return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") + if len(routeToSave.PeerGroups) > 0 { + if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil { + return nil, err + } } - return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + if len(routeToSave.AccessControlGroups) > 0 { + if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil { + return nil, err + } + } + + if err = validateGroups(routeToSave.Groups, groupsMap); err != nil { + return nil, err + } + + return groupsMap, nil } func toProtocolRoute(route *route.Route) *proto.Route { @@ -649,8 +720,21 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { return &portInfo } -// isRouteChangeAffectPeers checks if a given route affects peers by determining -// if it has a routing peer, distribution, or peer groups that include peers -func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { - return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +// areRouteChangesAffectPeers checks if a given route affects peers by determining +// if it has a routing peer, distribution, or peer groups that include peers. +func areRouteChangesAffectPeers(ctx context.Context, transaction Store, route *route.Route) (bool, error) { + if route.Peer != "" { + return true, nil + } + + hasPeers, err := anyGroupHasPeers(ctx, transaction, route.AccountID, route.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, route.AccountID, route.PeerGroups) } diff --git a/route/route.go b/route/route.go index e23801e6e..1428acc3d 100644 --- a/route/route.go +++ b/route/route.go @@ -88,18 +88,18 @@ type Route struct { // AccountID is a reference to Account that this object belongs AccountID string `gorm:"index"` // Network and Domains are mutually exclusive - Network netip.Prefix `gorm:"serializer:json"` - Domains domain.List `gorm:"serializer:json"` - KeepRoute bool - NetID NetID - Description string - Peer string - PeerGroups []string `gorm:"serializer:json"` - NetworkType NetworkType - Masquerade bool - Metric int - Enabled bool - Groups []string `gorm:"serializer:json"` + Network netip.Prefix `gorm:"serializer:json"` + Domains domain.List `gorm:"serializer:json"` + KeepRoute bool + NetID NetID + Description string + Peer string + PeerGroups []string `gorm:"serializer:json"` + NetworkType NetworkType + Masquerade bool + Metric int + Enabled bool + Groups []string `gorm:"serializer:json"` AccessControlGroups []string `gorm:"serializer:json"` } @@ -111,19 +111,20 @@ func (r *Route) EventMeta() map[string]any { // Copy copies a route object func (r *Route) Copy() *Route { route := &Route{ - ID: r.ID, - Description: r.Description, - NetID: r.NetID, - Network: r.Network, - Domains: slices.Clone(r.Domains), - KeepRoute: r.KeepRoute, - NetworkType: r.NetworkType, - Peer: r.Peer, - PeerGroups: slices.Clone(r.PeerGroups), - Metric: r.Metric, - Masquerade: r.Masquerade, - Enabled: r.Enabled, - Groups: slices.Clone(r.Groups), + ID: r.ID, + AccountID: r.AccountID, + Description: r.Description, + NetID: r.NetID, + Network: r.Network, + Domains: slices.Clone(r.Domains), + KeepRoute: r.KeepRoute, + NetworkType: r.NetworkType, + Peer: r.Peer, + PeerGroups: slices.Clone(r.PeerGroups), + Metric: r.Metric, + Masquerade: r.Masquerade, + Enabled: r.Enabled, + Groups: slices.Clone(r.Groups), AccessControlGroups: slices.Clone(r.AccessControlGroups), } return route @@ -149,7 +150,7 @@ func (r *Route) IsEqual(other *Route) bool { other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && - slices.Equal(r.PeerGroups, other.PeerGroups)&& + slices.Equal(r.PeerGroups, other.PeerGroups) && slices.Equal(r.AccessControlGroups, other.AccessControlGroups) }