diff --git a/README.md b/README.md index 1d2a976c2..c3b365694 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,9 @@
+ + +
@@ -29,13 +32,13 @@
See
Documentation
- Join our Slack channel + Join our Slack channel or our Community forum

- - New: NetBird Kubernetes Operator + + New: NetBird terraform provider

diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 9a7fa4d3d..61292ad18 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -32,11 +32,13 @@ func ipv4Checksum(header []byte) uint16 { var sum1, sum2 uint32 + // Parallel processing - unroll and compute two sums simultaneously sum1 += uint32(binary.BigEndian.Uint16(header[0:2])) sum2 += uint32(binary.BigEndian.Uint16(header[2:4])) sum1 += uint32(binary.BigEndian.Uint16(header[4:6])) sum2 += uint32(binary.BigEndian.Uint16(header[6:8])) sum1 += uint32(binary.BigEndian.Uint16(header[8:10])) + // Skip checksum field at [10:12] sum2 += uint32(binary.BigEndian.Uint16(header[12:14])) sum1 += uint32(binary.BigEndian.Uint16(header[14:16])) sum2 += uint32(binary.BigEndian.Uint16(header[16:18])) @@ -44,6 +46,7 @@ func ipv4Checksum(header []byte) uint16 { sum := sum1 + sum2 + // Handle remaining bytes for headers > 20 bytes for i := 20; i < len(header)-1; i += 2 { sum += uint32(binary.BigEndian.Uint16(header[i : i+2])) } @@ -52,6 +55,7 @@ func ipv4Checksum(header []byte) uint16 { sum += uint32(header[len(header)-1]) << 8 } + // Optimized carry fold - single iteration handles most cases sum = (sum & 0xFFFF) + (sum >> 16) if sum > 0xFFFF { sum++ @@ -65,6 +69,7 @@ func icmpChecksum(data []byte) uint16 { var sum1, sum2, sum3, sum4 uint32 i := 0 + // Process 16 bytes at once with 4 parallel accumulators for i <= len(data)-16 { sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2])) sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4])) @@ -79,6 +84,7 @@ func icmpChecksum(data []byte) uint16 { sum := sum1 + sum2 + sum3 + sum4 + // Handle remaining bytes for i < len(data)-1 { sum += uint32(binary.BigEndian.Uint16(data[i : i+2])) i += 2 @@ -255,6 +261,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr m.dnatMutex.Lock() defer m.dnatMutex.Unlock() + // Initialize both maps together if either is nil if m.dnatMappings == nil || m.dnatBiMap == nil { m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatBiMap = newBiDNATMap() @@ -482,12 +489,14 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) + // Fast path for IPv4 addresses (4 bytes) - most common case if len(oldBytes) == 4 && len(newBytes) == 4 { sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2])) sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) sum += uint32(binary.BigEndian.Uint16(newBytes[0:2])) sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) } else { + // Fallback for other lengths for i := 0; i < len(oldBytes)-1; i += 2 { sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2])) } diff --git a/client/internal/login.go b/client/internal/login.go index 677b7431a..8f9440fdd 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -156,7 +156,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm. ) loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) if err != nil { - log.Errorf("failed registering peer %v,%s", err, validSetupKey.String()) + log.Errorf("failed registering peer %v", err) return nil, err } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index af98986b1..e0974ab2a 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -144,6 +144,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { cr := m.initialClientRoutes(config.InitialRoutes) routesForComparison := slices.Clone(cr) + if config.DNSFeatureFlag { m.fakeIPManager = fakeip.NewManager() @@ -158,7 +159,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { cr = append(cr, fakeIPRoute) } - m.notifier.SetInitialClientRoutes(cr, routesForComparison, config.DNSFeatureFlag) + m.notifier.SetInitialClientRoutes(cr, routesForComparison) } func (m *DefaultManager) setupRefCounters(useNoop bool) { diff --git a/client/internal/routemanager/notifier/notifier.go b/client/internal/routemanager/notifier/notifier.go index b9e8d92d6..e69de29bb 100644 --- a/client/internal/routemanager/notifier/notifier.go +++ b/client/internal/routemanager/notifier/notifier.go @@ -1,145 +0,0 @@ -package notifier - -import ( - "net/netip" - "runtime" - "slices" - "sort" - "strings" - "sync" - - "github.com/netbirdio/netbird/client/internal/listener" - "github.com/netbirdio/netbird/route" -) - -type Notifier struct { - initialRoutes []*route.Route - routesForComparison []*route.Route - dnsFeatureFlag bool - - listener listener.NetworkChangeListener - listenerMux sync.Mutex -} - -func NewNotifier() *Notifier { - return &Notifier{} -} - -func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { - n.listenerMux.Lock() - defer n.listenerMux.Unlock() - n.listener = listener -} - -func (n *Notifier) SetInitialClientRoutes(allRoutes []*route.Route, routesForComparison []*route.Route, dnsFeatureFlag bool) { - n.dnsFeatureFlag = dnsFeatureFlag - n.initialRoutes = allRoutes - n.routesForComparison = routesForComparison -} - -func (n *Notifier) OnNewRoutes(idMap route.HAMap) { - if runtime.GOOS != "android" { - return - } - - var newRoutes []*route.Route - for _, routes := range idMap { - newRoutes = append(newRoutes, routes...) - } - - if !n.hasRouteDiff(n.routesForComparison, newRoutes) { - return - } - - n.routesForComparison = newRoutes - n.notify() -} - -// OnNewPrefixes is called from iOS only -func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { - newNets := make([]string, 0) - for _, prefix := range prefixes { - newNets = append(newNets, prefix.String()) - } - - sort.Strings(newNets) - - currentNets := n.routesToStrings(n.routesForComparison) - if slices.Equal(currentNets, newNets) { - return - } - - n.notify() -} - -func (n *Notifier) notify() { - n.listenerMux.Lock() - defer n.listenerMux.Unlock() - if n.listener == nil { - return - } - - routeStrings := n.routesToStrings(n.routesForComparison) - sort.Strings(routeStrings) - go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.routesForComparison), ",")) - }(n.listener) -} - -// hasRouteDiff compares two route slices for differences -func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool { - aFiltered := n.filterRoutes(a) - bFiltered := n.filterRoutes(b) - - slices.SortFunc(aFiltered, func(x, y *route.Route) int { - return strings.Compare(x.NetString(), y.NetString()) - }) - slices.SortFunc(bFiltered, func(x, y *route.Route) int { - return strings.Compare(x.NetString(), y.NetString()) - }) - - return !slices.EqualFunc(aFiltered, bFiltered, func(x, y *route.Route) bool { - return x.NetString() == y.NetString() - }) -} - -// filterRoutes filters routes based on DNS feature flag -func (n *Notifier) filterRoutes(routes []*route.Route) []*route.Route { - filtered := make([]*route.Route, 0, len(routes)) - for _, r := range routes { - if r.IsDynamic() && !n.dnsFeatureFlag { - // this kind of dynamic route is not supported on android - continue - } - filtered = append(filtered, r) - } - return filtered -} - -// routesToStrings converts routes to string slice (caller should sort if needed) -func (n *Notifier) routesToStrings(routes []*route.Route) []string { - filtered := n.filterRoutes(routes) - nets := make([]string, 0, len(filtered)) - for _, r := range filtered { - nets = append(nets, r.NetString()) - } - return nets -} - -func (n *Notifier) GetInitialRouteRanges() []string { - initialStrings := n.routesToStrings(n.initialRoutes) - sort.Strings(initialStrings) - return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes) -} - -// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route. -func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string { - for _, r := range routes { - // we are intentionally adding the ipv6 default range in case of ipv4 default range - // to ensure that all traffic is managed by the tunnel interface on android - if r.Network.Addr().Is4() && r.Network.Bits() == 0 { - return append(slices.Clone(inputRanges), "::/0") - } - } - return inputRanges -} diff --git a/client/internal/routemanager/notifier/notifier_android.go b/client/internal/routemanager/notifier/notifier_android.go new file mode 100644 index 000000000..dec0af87c --- /dev/null +++ b/client/internal/routemanager/notifier/notifier_android.go @@ -0,0 +1,127 @@ +//go:build android + +package notifier + +import ( + "net/netip" + "slices" + "sort" + "strings" + "sync" + + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/route" +) + +type Notifier struct { + initialRoutes []*route.Route + currentRoutes []*route.Route + + listener listener.NetworkChangeListener + listenerMux sync.Mutex +} + +func NewNotifier() *Notifier { + return &Notifier{} +} + +func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + n.listener = listener +} + +func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) { + // initialRoutes contains fake IP block for interface configuration + filteredInitial := make([]*route.Route, 0) + for _, r := range initialRoutes { + if r.IsDynamic() { + continue + } + filteredInitial = append(filteredInitial, r) + } + n.initialRoutes = filteredInitial + + // routesForComparison excludes fake IP block for comparison with new routes + filteredComparison := make([]*route.Route, 0) + for _, r := range routesForComparison { + if r.IsDynamic() { + continue + } + filteredComparison = append(filteredComparison, r) + } + n.currentRoutes = filteredComparison +} + +func (n *Notifier) OnNewRoutes(idMap route.HAMap) { + var newRoutes []*route.Route + for _, routes := range idMap { + for _, r := range routes { + if r.IsDynamic() { + continue + } + newRoutes = append(newRoutes, r) + } + } + + if !n.hasRouteDiff(n.currentRoutes, newRoutes) { + return + } + + n.currentRoutes = newRoutes + n.notify() +} + +func (n *Notifier) OnNewPrefixes([]netip.Prefix) { + // Not used on Android +} + +func (n *Notifier) notify() { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + if n.listener == nil { + return + } + + routeStrings := n.routesToStrings(n.currentRoutes) + sort.Strings(routeStrings) + go func(l listener.NetworkChangeListener) { + l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ",")) + }(n.listener) +} + +func (n *Notifier) routesToStrings(routes []*route.Route) []string { + nets := make([]string, 0, len(routes)) + for _, r := range routes { + nets = append(nets, r.NetString()) + } + return nets +} + +func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool { + slices.SortFunc(a, func(x, y *route.Route) int { + return strings.Compare(x.NetString(), y.NetString()) + }) + slices.SortFunc(b, func(x, y *route.Route) int { + return strings.Compare(x.NetString(), y.NetString()) + }) + + return !slices.EqualFunc(a, b, func(x, y *route.Route) bool { + return x.NetString() == y.NetString() + }) +} + +func (n *Notifier) GetInitialRouteRanges() []string { + initialStrings := n.routesToStrings(n.initialRoutes) + sort.Strings(initialStrings) + return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes) +} + +func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string { + for _, r := range routes { + if r.Network.Addr().Is4() && r.Network.Bits() == 0 { + return append(slices.Clone(inputRanges), "::/0") + } + } + return inputRanges +} diff --git a/client/internal/routemanager/notifier/notifier_ios.go b/client/internal/routemanager/notifier/notifier_ios.go new file mode 100644 index 000000000..bb125cfa4 --- /dev/null +++ b/client/internal/routemanager/notifier/notifier_ios.go @@ -0,0 +1,80 @@ +//go:build ios + +package notifier + +import ( + "net/netip" + "slices" + "sort" + "strings" + "sync" + + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/route" +) + +type Notifier struct { + currentPrefixes []string + + listener listener.NetworkChangeListener + listenerMux sync.Mutex +} + +func NewNotifier() *Notifier { + return &Notifier{} +} + +func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + n.listener = listener +} + +func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { + // iOS doesn't care about initial routes +} + +func (n *Notifier) OnNewRoutes(route.HAMap) { + // Not used on iOS +} + +func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { + newNets := make([]string, 0) + for _, prefix := range prefixes { + newNets = append(newNets, prefix.String()) + } + + sort.Strings(newNets) + + if slices.Equal(n.currentPrefixes, newNets) { + return + } + + n.currentPrefixes = newNets + n.notify() +} + +func (n *Notifier) notify() { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + if n.listener == nil { + return + } + + go func(l listener.NetworkChangeListener) { + l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ",")) + }(n.listener) +} + +func (n *Notifier) GetInitialRouteRanges() []string { + return nil +} + +func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string { + for _, r := range inputRanges { + if r == "0.0.0.0/0" { + return append(slices.Clone(inputRanges), "::/0") + } + } + return inputRanges +} diff --git a/client/internal/routemanager/notifier/notifier_other.go b/client/internal/routemanager/notifier/notifier_other.go new file mode 100644 index 000000000..77045b839 --- /dev/null +++ b/client/internal/routemanager/notifier/notifier_other.go @@ -0,0 +1,36 @@ +//go:build !android && !ios + +package notifier + +import ( + "net/netip" + + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/route" +) + +type Notifier struct{} + +func NewNotifier() *Notifier { + return &Notifier{} +} + +func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { + // Not used on non-mobile platforms +} + +func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { + // Not used on non-mobile platforms +} + +func (n *Notifier) OnNewRoutes(idMap route.HAMap) { + // Not used on non-mobile platforms +} + +func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { + // Not used on non-mobile platforms +} + +func (n *Notifier) GetInitialRouteRanges() []string { + return []string{} +} \ No newline at end of file diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 8caf22f81..106c520da 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -5,6 +5,7 @@ import ( "net" "net/netip" "sync" + "sync/atomic" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" @@ -52,6 +53,9 @@ type SysOps struct { mu sync.Mutex // notifier is used to notify the system of route changes (also used on mobile) notifier *notifier.Notifier + // seq is an atomic counter for generating unique sequence numbers for route messages + //nolint:unused // only used on BSD systems + seq atomic.Uint32 } func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { @@ -61,6 +65,11 @@ func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { } } +//nolint:unused // only used on BSD systems +func (r *SysOps) getSeq() int { + return int(r.seq.Add(1)) +} + func (r *SysOps) validateRoute(prefix netip.Prefix) error { addr := prefix.Addr() diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index f284e131b..46e5ca915 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -108,7 +108,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next Type: action, Flags: unix.RTF_UP, Version: unix.RTM_VERSION, - Seq: 1, + Seq: r.getSeq(), } const numAddrs = unix.RTAX_NETMASK + 1 diff --git a/management/server/account.go b/management/server/account.go index 3b7359502..cd0c933f0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -106,6 +106,18 @@ type DefaultAccountManager struct { disableDefaultPolicy bool } +func isUniqueConstraintError(err error) bool { + switch { + case strings.Contains(err.Error(), "(SQLSTATE 23505)"), + strings.Contains(err.Error(), "Error 1062 (23000)"), + strings.Contains(err.Error(), "UNIQUE constraint failed"): + return true + + default: + return false + } +} + // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // newly groups to create and an error if any occurred. @@ -1192,6 +1204,71 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) } +// GetAccountOnboarding retrieves the onboarding information for a specific account. +func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) + if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { + log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err) + return nil, err + } + + if onboarding == nil { + onboarding = &types.AccountOnboarding{ + AccountID: accountID, + } + } + + return onboarding, nil +} + +func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) + if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { + return nil, fmt.Errorf("failed to get account onboarding: %w", err) + } + + if oldOnboarding == nil { + oldOnboarding = &types.AccountOnboarding{ + AccountID: accountID, + } + } + + if newOnboarding == nil { + return oldOnboarding, nil + } + + if oldOnboarding.IsEqual(*newOnboarding) { + log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID) + return oldOnboarding, nil + } + + newOnboarding.AccountID = accountID + err = am.Store.SaveAccountOnboarding(ctx, newOnboarding) + if err != nil { + return nil, fmt.Errorf("failed to update account onboarding: %w", err) + } + + return newOnboarding, nil +} + func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) @@ -1661,25 +1738,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction return false, nil } -func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { - existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return "", fmt.Errorf("failed to get peer dns labels: %w", err) - } - - labelMap := ConvertSliceToMap(existingLabels) - newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap) - if err != nil { - return "", fmt.Errorf("failed to get new host label: %w", err) - } - - if newLabel == "" { - return "", fmt.Errorf("failed to get new host label: %w", err) - } - - return newLabel, nil -} - func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) if err != nil { @@ -1733,6 +1791,10 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, }, + Onboarding: types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, } if err := acc.AddAllGroup(disableDefaultPolicy); err != nil { diff --git a/management/server/account/manager.go b/management/server/account/manager.go index de5031c03..ed17fa5ec 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -39,6 +39,7 @@ type Manager interface { GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) + GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) @@ -89,6 +90,7 @@ type Manager interface { SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) + UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f319b81e..fcd40b082 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2623,11 +2623,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { account := &types.Account{ Id: "accountID", Peers: map[string]*nbpeer.Peer{ - "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, - "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, - "peer3": {ID: "peer3", Key: "key3", UserID: "user1"}, - "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, - "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, + "peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}, + "peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}, + "peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"}, + "peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"}, + "peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"}, }, Groups: map[string]*types.Group{ "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, @@ -3147,11 +3147,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 7, 20, 10, 80}, + {"Small", 50, 5, 7, 20, 5, 80}, {"Medium", 500, 100, 5, 40, 30, 140}, {"Large", 5000, 200, 80, 120, 140, 390}, - {"Small single", 50, 10, 7, 20, 10, 80}, - {"Medium single", 500, 10, 5, 40, 20, 85}, + {"Small single", 50, 10, 7, 20, 6, 80}, + {"Medium single", 500, 10, 5, 40, 15, 85}, {"Large 5", 5000, 15, 80, 120, 80, 200}, } @@ -3343,11 +3343,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) { account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain) require.NoError(t, err) - peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId} + peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1) require.NoError(t, err) - peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId} + peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2) require.NoError(t, err) @@ -3448,3 +3448,74 @@ func TestPropagateUserGroupMemberships(t *testing.T) { } }) } + +func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + require.NoError(t, err) + + t.Run("should return account onboarding when onboarding exist", func(t *testing.T) { + onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) + require.NoError(t, err) + require.NotNil(t, onboarding) + assert.Equal(t, account.Id, onboarding.AccountID) + assert.Equal(t, true, onboarding.OnboardingFlowPending) + assert.Equal(t, true, onboarding.SignupFormPending) + if onboarding.UpdatedAt.IsZero() { + t.Errorf("Onboarding was not retrieved from the store") + } + }) + + t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) { + account.Id = "with-zero-onboarding" + account.Onboarding = types.AccountOnboarding{} + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) + require.NoError(t, err) + require.NotNil(t, onboarding) + _, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id) + require.Error(t, err, "should return error when onboarding is not set") + }) +} + +func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + require.NoError(t, err) + + onboarding := &types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + } + + t.Run("update onboarding with no change", func(t *testing.T) { + updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding) + require.NoError(t, err) + assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending) + assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending) + if updated.UpdatedAt.IsZero() { + t.Errorf("Onboarding was updated in the store") + } + }) + + onboarding.OnboardingFlowPending = false + onboarding.SignupFormPending = false + + t.Run("update onboarding", func(t *testing.T) { + updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding) + require.NoError(t, err) + require.NotNil(t, updated) + assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending) + assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending) + }) + + t.Run("update onboarding with no onboarding", func(t *testing.T) { + _, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil) + require.NoError(t, err) + }) +} diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 1c5ca9b04..f8c2b9854 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -60,6 +60,8 @@ components: description: Account creator type: string example: google-oauth2|277474792786460067937 + onboarding: + $ref: '#/components/schemas/AccountOnboarding' required: - id - settings @@ -67,6 +69,21 @@ components: - domain_category - created_at - created_by + - onboarding + AccountOnboarding: + type: object + properties: + signup_form_pending: + description: Indicates whether the account signup form is pending + type: boolean + example: true + onboarding_flow_pending: + description: Indicates whether the account onboarding flow is pending + type: boolean + example: false + required: + - signup_form_pending + - onboarding_flow_pending AccountSettings: type: object properties: @@ -153,6 +170,8 @@ components: properties: settings: $ref: '#/components/schemas/AccountSettings' + onboarding: + $ref: '#/components/schemas/AccountOnboarding' required: - settings User: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index d27fd2a57..a9f17aab4 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -250,8 +250,9 @@ type Account struct { DomainCategory string `json:"domain_category"` // Id Account ID - Id string `json:"id"` - Settings AccountSettings `json:"settings"` + Id string `json:"id"` + Onboarding AccountOnboarding `json:"onboarding"` + Settings AccountSettings `json:"settings"` } // AccountExtraSettings defines model for AccountExtraSettings. @@ -266,9 +267,19 @@ type AccountExtraSettings struct { PeerApprovalEnabled bool `json:"peer_approval_enabled"` } +// AccountOnboarding defines model for AccountOnboarding. +type AccountOnboarding struct { + // OnboardingFlowPending Indicates whether the account onboarding flow is pending + OnboardingFlowPending bool `json:"onboarding_flow_pending"` + + // SignupFormPending Indicates whether the account signup form is pending + SignupFormPending bool `json:"signup_form_pending"` +} + // AccountRequest defines model for AccountRequest. type AccountRequest struct { - Settings AccountSettings `json:"settings"` + Onboarding *AccountOnboarding `json:"onboarding,omitempty"` + Settings AccountSettings `json:"settings"` } // AccountSettings defines model for AccountSettings. diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index dfc782b3f..ab59434d1 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -59,7 +59,13 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, settings, meta) + onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toAccountResponse(accountID, settings, meta, onboarding) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -126,6 +132,20 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled } + var onboarding *types.AccountOnboarding + if req.Onboarding != nil { + onboarding = &types.AccountOnboarding{ + OnboardingFlowPending: req.Onboarding.OnboardingFlowPending, + SignupFormPending: req.Onboarding.SignupFormPending, + } + } + + updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) @@ -138,7 +158,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, updatedSettings, meta) + resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding) util.WriteJSONObject(r.Context(), w, &resp) } @@ -167,7 +187,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -188,6 +208,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A DnsDomain: &settings.DNSDomain, } + apiOnboarding := api.AccountOnboarding{ + OnboardingFlowPending: onboarding.OnboardingFlowPending, + SignupFormPending: onboarding.SignupFormPending, + } + if settings.Extra != nil { apiSettings.Extra = &api.AccountExtraSettings{ PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, @@ -203,5 +228,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A CreatedBy: meta.CreatedBy, Domain: meta.Domain, DomainCategory: meta.DomainCategory, + Onboarding: apiOnboarding, } } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index a18798743..dbf0c22bc 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -54,6 +54,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { return account.GetMeta(), nil }, + GetAccountOnboardingFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { + return &types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, nil + }, + UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + return &types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, nil + }, }, settingsManager: settingsMockManager, } @@ -117,7 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -139,7 +151,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -161,7 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 554400, @@ -178,12 +190,34 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedArray: false, expectedID: accountID, }, + { + name: "PutAccount OK without onboarding", + expectedBody: true, + requestType: http.MethodPut, + requestPath: "/api/accounts/" + accountID, + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), + expectedStatus: http.StatusOK, + expectedSettings: api.AccountSettings{ + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr("roles"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), + LazyConnectionEnabled: br(false), + DnsDomain: sr(""), + }, + expectedArray: false, + expectedID: accountID, + }, { name: "Update account failure with high peer_login_expiration more than 180 days", expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusUnprocessableEntity, expectedArray: false, }, @@ -192,7 +226,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusUnprocessableEntity, expectedArray: false, }, diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index c8a852e0a..ab11be731 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -373,3 +373,42 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model) return nil } + +func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error { + var model T + + stmt := &gorm.Statement{DB: db} + if err := stmt.Parse(&model); err != nil { + return fmt.Errorf("failed to parse model schema: %w", err) + } + tableName := stmt.Schema.Table + dialect := db.Dialector.Name() + + var columnClause string + if dialect == "mysql" { + var withLength []string + for _, col := range columns { + if col == "ip" || col == "dns_label" { + withLength = append(withLength, fmt.Sprintf("%s(64)", col)) + } else { + withLength = append(withLength, col) + } + } + columnClause = strings.Join(withLength, ", ") + } else { + columnClause = strings.Join(columns, ", ") + } + + createStmt := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, columnClause) + if dialect == "postgres" || dialect == "sqlite" { + createStmt = strings.Replace(createStmt, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1) + } + + log.WithContext(ctx).Infof("executing index creation: %s", createStmt) + if err := db.Exec(createStmt).Error; err != nil { + return fmt.Errorf("failed to create index %s: %w", indexName, err) + } + + log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName) + return nil +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 3caa6744a..8837f9f50 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -30,94 +30,95 @@ type MockAccountManager struct { GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) - GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) - AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) - GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) - ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) - GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error - GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) - GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) - GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) - SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error - SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error - DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error - DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error - GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error - GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) - DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error - GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) - DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error - ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) - GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) - UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) - GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) - SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error - DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error - ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) - ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) - SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) - SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) - SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) - DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error - DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error - CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) - DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) - GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) - GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error - ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) - GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) - DeleteAccountFunc func(ctx context.Context, accountID, userID string) error - GetDNSDomainFunc func(settings *types.Settings) string - StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) - GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) - SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error - GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) - LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error - GetAllConnectedPeersFunc func() (map[string]struct{}, error) - HasConnectedChannelFunc func(peerID string) bool - GetExternalCacheManagerFunc func() account.ExternalCacheManager - GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) - DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error - ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) - GetIdpManagerFunc func() idp.Manager - UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error - GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error) - SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error - FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) - GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) - GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) - GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) - GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) - DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error - BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) - GetStoreFunc func() store.Store - UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) - GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) - GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) - GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) - + GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) + AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) + GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) + GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error + GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) + GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) + AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error + DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error + DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error + GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) + DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error + GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) + DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error + ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) + GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) + UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error + UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) + SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error + DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error + ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) + SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) + SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) + DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error + CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) + DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error + GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) + GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error + DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error + ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) + CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) + GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + DeleteAccountFunc func(ctx context.Context, accountID, userID string) error + GetDNSDomainFunc func(settings *types.Settings) string + StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) + GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) + GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) + SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error + GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) + LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error + GetAllConnectedPeersFunc func() (map[string]struct{}, error) + HasConnectedChannelFunc func(peerID string) bool + GetExternalCacheManagerFunc func() account.ExternalCacheManager + GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) + DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error + ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) + GetIdpManagerFunc func() idp.Manager + UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error + GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error) + SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error + FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) + DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error + BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) + GetStoreFunc func() store.Store + UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) + GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) + GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) + GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) + UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) } @@ -814,6 +815,22 @@ func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID stri return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented") } +// GetAccountOnboarding mocks GetAccountOnboarding of the AccountManager interface +func (am *MockAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { + if am.GetAccountOnboardingFunc != nil { + return am.GetAccountOnboardingFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountOnboarding is not implemented") +} + +// UpdateAccountOnboarding mocks UpdateAccountOnboarding of the AccountManager interface +func (am *MockAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID string, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + if am.UpdateAccountOnboardingFunc != nil { + return am.UpdateAccountOnboardingFunc(ctx, accountID, userID, onboarding) + } + return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountOnboarding is not implemented") +} + // GetUserByID mocks GetUserByID of the AccountManager interface func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { diff --git a/management/server/peer.go b/management/server/peer.go index 254048a96..1dd390dd9 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -15,13 +15,14 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/idp" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -234,14 +235,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peer.Name != update.Name { - existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID) + var newLabel string + newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name) if err != nil { - return err - } - - newLabel, err := types.GetPeerHostLabel(update.Name, existingLabels) - if err != nil { - return err + return fmt.Errorf("failed to get free DNS label: %w", err) } peer.Name = update.Name @@ -463,208 +460,232 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s upperKey := strings.ToUpper(setupKey) hashedKey := sha256.Sum256([]byte(upperKey)) encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) - var accountID string - var err error - addedByUser := false - if len(userID) > 0 { - addedByUser = true - accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) - } else { - accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) - } - if err != nil { - return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer func() { - if unlock != nil { - unlock() - } - }() + addedByUser := len(userID) > 0 // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow) // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. - _, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key) + _, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key) if err == nil { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } opEvent := &activity.Event{ Timestamp: time.Now().UTC(), - AccountID: accountID, } var newPeer *nbpeer.Peer var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - var setupKeyID string - var setupKeyName string - var ephemeral bool - var groupsToAdd []string - var allowExtraDNSLabels bool - if addedByUser { - user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID) - if err != nil { - return fmt.Errorf("failed to get user groups: %w", err) - } - groupsToAdd = user.AutoGroups - opEvent.InitiatorID = userID - opEvent.Activity = activity.PeerAddedByUser - } else { - // Validate the setup key - sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey) - if err != nil { - return fmt.Errorf("failed to get setup key: %w", err) - } - - if !sk.IsValid() { - return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") - } - - opEvent.InitiatorID = sk.Id - opEvent.Activity = activity.PeerAddedWithSetupKey - groupsToAdd = sk.AutoGroups - ephemeral = sk.Ephemeral - setupKeyID = sk.Id - setupKeyName = sk.Name - allowExtraDNSLabels = sk.AllowExtraDNSLabels - - if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { - return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") - } - } - - if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { - if am.idpManager != nil { - userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) - if err == nil && userdata != nil { - peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) - } - } - } - - freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname) + var setupKeyID string + var setupKeyName string + var ephemeral bool + var groupsToAdd []string + var allowExtraDNSLabels bool + var accountID string + if addedByUser { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { - return fmt.Errorf("failed to get free DNS label: %w", err) + return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found") } - - freeIP, err := getFreeIP(ctx, transaction, accountID) + groupsToAdd = user.AutoGroups + opEvent.InitiatorID = userID + opEvent.Activity = activity.PeerAddedByUser + accountID = user.AccountID + } else { + // Validate the setup key + sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey) if err != nil { - return fmt.Errorf("failed to get free IP: %w", err) + return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid") } - if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil { - return status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err) + // we will check key twice for early return + if !sk.IsValid() { + return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid") } - registrationTime := time.Now().UTC() - newPeer = &nbpeer.Peer{ - ID: xid.New().String(), - AccountID: accountID, - Key: peer.Key, - IP: freeIP, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: freeLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: ®istrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - Location: peer.Location, - InactivityExpirationEnabled: addedByUser, - ExtraDNSLabels: peer.ExtraDNSLabels, - AllowExtraDNSLabels: allowExtraDNSLabels, - } - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return fmt.Errorf("failed to get account settings: %w", err) - } + opEvent.InitiatorID = sk.Id + opEvent.Activity = activity.PeerAddedWithSetupKey + groupsToAdd = sk.AutoGroups + ephemeral = sk.Ephemeral + setupKeyID = sk.Id + setupKeyName = sk.Name + allowExtraDNSLabels = sk.AllowExtraDNSLabels + accountID = sk.AccountID - opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) - if !addedByUser { - opEvent.Meta["setup_key_name"] = setupKeyName + if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { + return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") } + } + opEvent.AccountID = accountID - if am.geo != nil && newPeer.Location.ConnectionIP != nil { - location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) - if err != nil { - log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) - } else { - newPeer.Location.CountryCode = location.Country.ISOCode - newPeer.Location.CityName = location.City.Names.En - newPeer.Location.GeoNameID = location.City.GeonameID + if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { + if am.idpManager != nil { + userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) + if err == nil && userdata != nil { + peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) } } + } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) - - err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) - if err != nil { - return fmt.Errorf("failed adding peer to All group: %w", err) - } - - if len(groupsToAdd) > 0 { - for _, g := range groupsToAdd { - err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g) - if err != nil { - return err - } - } - } - - err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) - if err != nil { - return fmt.Errorf("failed to add peer to account: %w", err) - } - - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) - if err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) - } - - if addedByUser { - err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) - if err != nil { - log.WithContext(ctx).Debugf("failed to update user last login: %v", err) - } - } else { - err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID) - if err != nil { - return fmt.Errorf("failed to increment setup key usage: %w", err) - } - } - - updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID) - if err != nil { - return err - } - - log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) - return nil - }) + if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil { + return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err) + } + registrationTime := time.Now().UTC() + newPeer = &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: ®istrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + InactivityExpirationEnabled: addedByUser, + ExtraDNSLabels: peer.ExtraDNSLabels, + AllowExtraDNSLabels: allowExtraDNSLabels, + } + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err) + } + + if am.geo != nil && newPeer.Location.ConnectionIP != nil { + location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) + if err != nil { + log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) + } else { + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + } + } + + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + + network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed getting network: %w", err) + } + + maxAttempts := 10 + for attempt := 1; attempt <= maxAttempts; attempt++ { + var freeIP net.IP + freeIP, err = types.AllocateRandomPeerIP(network.Net) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err) + } + + var freeLabel string + freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) + } + + newPeer.DNSLabel = freeLabel + newPeer.IP = freeIP + + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer func() { + if unlock != nil { + unlock() + } + }() + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) + if err != nil { + return err + } + + err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) + if err != nil { + return fmt.Errorf("failed adding peer to All group: %w", err) + } + + if len(groupsToAdd) > 0 { + for _, g := range groupsToAdd { + err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g) + if err != nil { + return err + } + } + } + + if addedByUser { + err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) + if err != nil { + log.WithContext(ctx).Debugf("failed to update user last login: %v", err) + } + } else { + sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey) + if err != nil { + return fmt.Errorf("failed to get setup key: %w", err) + } + + // we validate at the end to not block the setup key for too long + if !sk.IsValid() { + return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") + } + + err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID) + if err != nil { + return fmt.Errorf("failed to increment setup key usage: %w", err) + } + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) + return nil + }) + if err == nil { + unlock() + unlock = nil + break + } + + if isUniqueConstraintError(err) { + unlock() + unlock = nil + log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err) + continue + } + return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err) } + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) + } + + updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) + if err != nil { + updateAccountPeers = true + } if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } - am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + opEvent.TargetID = newPeer.ID + opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) + if !addedByUser { + opEvent.Meta["setup_key_name"] = setupKeyName + } - unlock() - unlock = nil + am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) if updateAccountPeers { am.BufferUpdateAccountPeers(ctx, accountID) @@ -673,23 +694,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } -func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) { - takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID) +func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) { + ip = ip.To4() + + dnsName, err := nbdns.GetParsedDomainLabel(peerHostName) if err != nil { - return nil, fmt.Errorf("failed to get taken IPs: %w", err) + return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err) } - network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID) + _, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName) if err != nil { - return nil, fmt.Errorf("failed getting network: %w", err) + //nolint:nilerr + return dnsName, nil } - nextIp, err := types.AllocatePeerIP(network.Net, takenIps) - if err != nil { - return nil, fmt.Errorf("failed to allocate new peer ip: %w", err) - } - - return nextIp, nil + return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible @@ -838,7 +857,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer if login.UserID != "" { if peer.UserID != login.UserID { log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) - return status.Errorf(status.Unauthenticated, "invalid user") + return status.NewPeerLoginMismatchError() } changed, err := am.handleUserPeer(ctx, transaction, peer, settings) @@ -1087,7 +1106,7 @@ func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error } if peer.UserID != loginUserID { log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) - return status.Errorf(status.Unauthenticated, "can't login with this credentials") + return status.NewPeerLoginMismatchError() } return nil } @@ -1477,19 +1496,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str return groupIDs, err } -func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { - dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - - existingLabels := make(types.LookupMap) - for _, label := range dnsLabels { - existingLabels[label] = struct{}{} - } - return existingLabels, nil -} - // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 8ce1dfb4e..f7140e254 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -20,14 +20,14 @@ type Peer struct { // WireGuard public key Key string `gorm:"index"` // IP address of the Peer - IP net.IP `gorm:"serializer:json"` + IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations) // Meta is a Peer system meta data Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud - DNSLabel string + DNSLabel string // uniqueness index per accountID (check migrations) // Status peer's management connection status Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` // The user ID that registered the peer diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 775385a29..31439d670 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -10,7 +10,9 @@ import ( "net/netip" "os" "runtime" + "strconv" "strings" + "sync" "testing" "time" @@ -19,11 +21,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/util" @@ -1373,6 +1377,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { existingSetupKeyID string expectedGroupIDsInAccount []string expectAddPeerError bool + errorType status.Type expectedErrorMsgSubstring string }{ { @@ -1385,13 +1390,15 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { name: "Failed registration with setup key not allowing extra DNS labels", existingSetupKeyID: "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", expectAddPeerError: true, + errorType: status.PreconditionFailed, expectedErrorMsgSubstring: "setup key doesn't allow extra DNS labels", }, { name: "Absent setup key", existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB", expectAddPeerError: true, - expectedErrorMsgSubstring: "failed adding new peer: account not found", + errorType: status.NotFound, + expectedErrorMsgSubstring: "couldn't add peer: setup key is invalid", }, } @@ -1416,6 +1423,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { if tc.expectAddPeerError { require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID) assert.Contains(t, err.Error(), tc.expectedErrorMsgSubstring, "Error message mismatch") + e, ok := status.FromError(err) + if !ok { + t.Fatal("Failed to map error") + } + assert.Equal(t, e.Type(), tc.errorType) return } @@ -2057,10 +2069,14 @@ func Test_DeletePeer(t *testing.T) { "peer1": { ID: "peer1", AccountID: accountID, + IP: net.IP{1, 1, 1, 1}, + DNSLabel: "peer1.test", }, "peer2": { ID: "peer2", AccountID: accountID, + IP: net.IP{2, 2, 2, 2}, + DNSLabel: "peer2.test", }, } account.Groups = map[string]*types.Group{ @@ -2090,3 +2106,138 @@ func Test_DeletePeer(t *testing.T) { assert.NotContains(t, group.Peers, "peer1") } + +func Test_IsUniqueConstraintError(t *testing.T) { + tests := []struct { + name string + engine types.Engine + }{ + { + name: "PostgreSQL uniqueness error", + engine: types.PostgresStoreEngine, + }, + { + name: "MySQL uniqueness error", + engine: types.MysqlStoreEngine, + }, + { + name: "SQLite uniqueness error", + engine: types.SqliteStoreEngine, + }, + } + + peer := &nbpeer.Peer{ + ID: "test-peer-id", + AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + DNSLabel: "test-peer-dns-label", + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(tt.engine)) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) + assert.NoError(t, err) + + err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) + result := isUniqueConstraintError(err) + assert.True(t, result) + }) + } +} + +func Test_AddPeer(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + _, err = createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", types.SetupKeyReusable, time.Hour, nil, 10000, userID, false, false) + if err != nil { + t.Fatal("error creating setup key") + return + } + + const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries) + const differentHostnames = 50 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers+differentHostnames) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + hostNameID := i % differentHostnames + + go func(i int) { + defer wg.Done() + + newPeer := &nbpeer.Peer{ + Key: "key" + strconv.Itoa(i), + Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"}, + } + + <-start + + _, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer) + if err != nil { + errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + return + } + + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers)) + + seenIP := make(map[string]bool) + for _, p := range account.Peers { + ipStr := p.IP.String() + if seenIP[ipStr] { + t.Fatalf("Duplicate IP found in account %s: %s", accountID, ipStr) + } + seenIP[ipStr] = true + } + + seenLabel := make(map[string]bool) + for _, p := range account.Peers { + if seenLabel[p.DNSLabel] { + t.Fatalf("Duplicate Label found in account %s: %s", accountID, p.DNSLabel) + } + seenLabel[p.DNSLabel] = true + } + + assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes) + assert.Equal(t, uint64(totalPeers), account.Network.Serial) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 5a6f6d1a7..e3cc27b29 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -90,6 +90,11 @@ func NewAccountNotFoundError(accountKey string) error { return Errorf(NotFound, "account not found: %s", accountKey) } +// NewAccountOnboardingNotFoundError creates a new Error with NotFound type for a missing account onboarding +func NewAccountOnboardingNotFoundError(accountKey string) error { + return Errorf(NotFound, "account onboarding not found: %s", accountKey) +} + // NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account func NewPeerNotPartOfAccountError() error { return Errorf(PermissionDenied, "peer is not part of this account") @@ -105,11 +110,16 @@ func NewUserBlockedError() error { return Errorf(PermissionDenied, "user is blocked") } -// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer +// NewPeerNotRegisteredError creates a new Error with Unauthenticated type unregistered peer func NewPeerNotRegisteredError() error { return Errorf(Unauthenticated, "peer is not registered") } +// NewPeerLoginMismatchError creates a new Error with Unauthenticated type for a peer that is already registered for another user +func NewPeerLoginMismatchError() error { + return Errorf(Unauthenticated, "peer is already registered by a different User or a Setup Key") +} + // NewPeerLoginExpiredError creates a new Error with PermissionDenied type for an expired peer func NewPeerLoginExpiredError() error { return Errorf(PermissionDenied, "peer login has expired, please log in once more") diff --git a/management/server/store/file_store.go b/management/server/store/file_store.go index 3b95164f5..d5d9337ca 100644 --- a/management/server/store/file_store.go +++ b/management/server/store/file_store.go @@ -156,7 +156,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { allGroup, err := account.GetGroupAll() if err != nil { - log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) + log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migratePreAuto from a version that didn't support groups. Error: %v", err) // if the All group didn't exist we probably don't have routes to update continue } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 72a73a57a..e380a7da7 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -92,17 +92,20 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil } - if err := migrate(ctx, db); err != nil { - return nil, fmt.Errorf("migrate: %w", err) + if err := migratePreAuto(ctx, db); err != nil { + return nil, fmt.Errorf("migratePreAuto: %w", err) } err = db.AutoMigrate( &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, - &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, ) if err != nil { - return nil, fmt.Errorf("auto migrate: %w", err) + return nil, fmt.Errorf("auto migratePreAuto: %w", err) + } + if err := migratePostAuto(ctx, db); err != nil { + return nil, fmt.Errorf("migratePostAuto: %w", err) } return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil @@ -725,6 +728,32 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren return &accountMeta, nil } +// GetAccountOnboarding retrieves the onboarding information for a specific account. +func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) { + var accountOnboarding types.AccountOnboarding + result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountOnboardingNotFoundError(accountID) + } + log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error) + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + return &accountOnboarding, nil +} + +// SaveAccountOnboarding updates the onboarding information for a specific account. +func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error { + result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error) + return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error) + } + + return nil +} + func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { @@ -967,7 +996,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength return ips, nil } -func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { +func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string, dnsLabel string) ([]string, error) { tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -975,7 +1004,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock var labels []string result := tx.Model(&nbpeer.Peer{}). - Where("account_id = ?", accountID). + Where("account_id = ? AND dns_label LIKE ?", accountID, dnsLabel+"%"). Pluck("dns_label", &labels) if result.Error != nil { @@ -1254,7 +1283,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.NewSetupKeyNotFoundError(key) + return nil, status.Errorf(status.PreconditionFailed, "setup key not found") } log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error) return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store") @@ -1410,7 +1439,11 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng // GetAccountPeers retrieves peers for an account. func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer - query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountIDCondition, accountID) + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + query := tx.Where(accountIDCondition, accountID) if nameFilter != "" { query = query.Where("name LIKE ?", "%"+nameFilter+"%") @@ -2546,6 +2579,27 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength return &peer, nil } +func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var peerID string + result := tx.Model(&nbpeer.Peer{}). + Select("id"). + // Where(" = ?", hostname). + Where("account_id = ? AND dns_label = ?", accountID, hostname). + Limit(1). + Scan(&peerID) + + if peerID == "" { + return "", gorm.ErrRecordNotFound + } + + return peerID, result.Error +} + func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) { var count int64 result := s.db.Model(&types.Account{}). diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index f187be8c7..738c5a28c 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -10,6 +10,7 @@ import ( "net/netip" "os" "runtime" + "sort" "sync" "testing" "time" @@ -353,9 +354,16 @@ func TestSqlite_DeleteAccount(t *testing.T) { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } + o, err := store.GetAccountOnboarding(context.Background(), account.Id) + require.NoError(t, err) + require.Equal(t, o.AccountID, account.Id) + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) + _, err = store.GetAccountOnboarding(context.Background(), account.Id) + require.Error(t, err, "expecting error after removing DeleteAccount when getting onboarding") + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") } @@ -413,12 +421,21 @@ func Test_GetAccount(t *testing.T) { account, err := store.GetAccount(context.Background(), id) require.NoError(t, err) require.Equal(t, id, account.Id, "account id should match") + require.Equal(t, false, account.Onboarding.OnboardingFlowPending) + + id = "9439-34653001fc3b-bf1c8084-ba50-4ce7" + + account, err = store.GetAccount(context.Background(), id) + require.NoError(t, err) + require.Equal(t, id, account.Id, "account id should match") + require.Equal(t, true, account.Onboarding.OnboardingFlowPending) _, err = store.GetAccount(context.Background(), "non-existing-account") assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } @@ -630,7 +647,7 @@ func TestMigrate(t *testing.T) { t.Cleanup(cleanUp) assert.NoError(t, err) - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on empty db") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") @@ -685,10 +702,10 @@ func TestMigrate(t *testing.T) { err = store.(*SqlStore).db.Save(rt).Error require.NoError(t, err, "Failed to insert Gob data") - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on gob populated db") - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error @@ -704,10 +721,10 @@ func TestMigrate(t *testing.T) { err = store.(*SqlStore).db.Save(nRT).Error require.NoError(t, err, "Failed to insert json nil slice data") - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on json nil slice populated db") - err = migrate(context.Background(), store.(*SqlStore).db) + err = migratePreAuto(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") } @@ -950,6 +967,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { peer1 := &nbpeer.Peer{ ID: "peer1", AccountID: existingAccountID, + DNSLabel: "peer1", IP: net.IP{1, 1, 1, 1}, } err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) @@ -961,8 +979,9 @@ func TestSqlite_GetTakenIPs(t *testing.T) { assert.Equal(t, []net.IP{ip1}, takenIPs) peer2 := &nbpeer.Peer{ - ID: "peer2", + ID: "peer1second", AccountID: existingAccountID, + DNSLabel: "peer1-1", IP: net.IP{2, 2, 2, 2}, } err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) @@ -972,49 +991,100 @@ func TestSqlite_GetTakenIPs(t *testing.T) { require.NoError(t, err) ip2 := net.IP{2, 2, 2, 2}.To16() assert.Equal(t, []net.IP{ip1, ip2}, takenIPs) - } func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) - if err != nil { - return - } - t.Cleanup(cleanup) + runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) { + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerHostname := "peer1" - existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) - _, err = store.GetAccount(context.Background(), existingAccountID) - require.NoError(t, err) + labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname) + require.NoError(t, err) + assert.Equal(t, []string{}, labels) - labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) - require.NoError(t, err) - assert.Equal(t, []string{}, labels) + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + DNSLabel: "peer1", + IP: net.IP{1, 1, 1, 1}, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + require.NoError(t, err) - peer1 := &nbpeer.Peer{ - ID: "peer1", - AccountID: existingAccountID, - DNSLabel: "peer1.domain.test", - } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) - require.NoError(t, err) + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname) + require.NoError(t, err) + assert.Equal(t, []string{"peer1"}, labels) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) - require.NoError(t, err) - assert.Equal(t, []string{"peer1.domain.test"}, labels) + peer2 := &nbpeer.Peer{ + ID: "peer1second", + AccountID: existingAccountID, + DNSLabel: "peer1-1", + IP: net.IP{2, 2, 2, 2}, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + require.NoError(t, err) - peer2 := &nbpeer.Peer{ - ID: "peer2", - AccountID: existingAccountID, - DNSLabel: "peer2.domain.test", - } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) - require.NoError(t, err) + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname) + require.NoError(t, err) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) - require.NoError(t, err) - assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels) + expected := []string{"peer1", "peer1-1"} + sort.Strings(expected) + sort.Strings(labels) + assert.Equal(t, expected, labels) + }) +} + +func Test_AddPeerWithSameDnsLabel(t *testing.T) { + runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) { + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + DNSLabel: "peer1.domain.test", + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + require.NoError(t, err) + + peer2 := &nbpeer.Peer{ + ID: "peer1second", + AccountID: existingAccountID, + DNSLabel: "peer1.domain.test", + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + require.Error(t, err) + }) +} + +func Test_AddPeerWithSameIP(t *testing.T) { + runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) { + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + IP: net.IP{1, 1, 1, 1}, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + require.NoError(t, err) + + peer2 := &nbpeer.Peer{ + ID: "peer1second", + AccountID: existingAccountID, + IP: net.IP{1, 1, 1, 1}, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + require.Error(t, err) + }) } func TestSqlite_GetAccountNetwork(t *testing.T) { @@ -2042,6 +2112,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, }, + Onboarding: types.AccountOnboarding{SignupFormPending: true, OnboardingFlowPending: true}, } if err := acc.AddAllGroup(false); err != nil { @@ -3386,6 +3457,63 @@ func TestSqlStore_GetAccountMeta(t *testing.T) { require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC()) } +func TestSqlStore_GetAccountOnboarding(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7" + a, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + t.Logf("Onboarding: %+v", a.Onboarding) + err = store.SaveAccount(context.Background(), a) + require.NoError(t, err) + onboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + require.NotNil(t, onboarding) + require.Equal(t, accountID, onboarding.AccountID) + require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), onboarding.CreatedAt.UTC()) +} + +func TestSqlStore_SaveAccountOnboarding(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + t.Run("New onboarding should be saved correctly", func(t *testing.T) { + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + onboarding := &types.AccountOnboarding{ + AccountID: accountID, + SignupFormPending: true, + OnboardingFlowPending: true, + } + + err = store.SaveAccountOnboarding(context.Background(), onboarding) + require.NoError(t, err) + + savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending) + require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending) + }) + + t.Run("Existing onboarding should be updated correctly", func(t *testing.T) { + accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7" + onboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + + onboarding.OnboardingFlowPending = !onboarding.OnboardingFlowPending + onboarding.SignupFormPending = !onboarding.SignupFormPending + + err = store.SaveAccountOnboarding(context.Background(), onboarding) + require.NoError(t, err) + + savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending) + require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending) + }) +} + func TestSqlStore_GetAnyAccountID(t *testing.T) { t.Run("should return account ID when accounts exist", func(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) diff --git a/management/server/store/store.go b/management/server/store/store.go index f66130ad3..b3254c4c9 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -52,6 +52,7 @@ type Store interface { GetAllAccounts(ctx context.Context) []*types.Account GetAccount(ctx context.Context, accountID string) (*types.Account, error) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) + GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) @@ -74,6 +75,7 @@ type Store interface { SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) + SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) @@ -117,7 +119,7 @@ type Store interface { SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error - GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) @@ -193,6 +195,7 @@ type Store interface { SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) + GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) } const ( @@ -234,9 +237,9 @@ func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) type if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) { log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile) - // Attempt to migrate from JSON store to SQLite + // Attempt to migratePreAuto from JSON store to SQLite if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil { - log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err) + log.WithContext(ctx).Errorf("failed to migratePreAuto filestore to SQLite: %v", err) kind = types.FileStoreEngine } } @@ -280,9 +283,9 @@ func checkFileStoreEngine(kind types.Engine, dataDir string) error { return nil } -// migrate migrates the SQLite database to the latest schema -func migrate(ctx context.Context, db *gorm.DB) error { - migrations := getMigrations(ctx) +// migratePreAuto migrates the SQLite database to the latest schema +func migratePreAuto(ctx context.Context, db *gorm.DB) error { + migrations := getMigrationsPreAuto(ctx) for _, m := range migrations { if err := m(db); err != nil { @@ -293,7 +296,7 @@ func migrate(ctx context.Context, db *gorm.DB) error { return nil } -func getMigrations(ctx context.Context) []migrationFunc { +func getMigrationsPreAuto(ctx context.Context) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net") @@ -329,6 +332,28 @@ func getMigrations(ctx context.Context) []migrationFunc { return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id") }, } +} // migratePostAuto migrates the SQLite database to the latest schema +func migratePostAuto(ctx context.Context, db *gorm.DB) error { + migrations := getMigrationsPostAuto(ctx) + + for _, m := range migrations { + if err := m(db); err != nil { + return err + } + } + + return nil +} + +func getMigrationsPostAuto(ctx context.Context) []migrationFunc { + return []migrationFunc{ + func(db *gorm.DB) error { + return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_ip", "account_id", "ip") + }, + func(db *gorm.DB) error { + return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label") + }, + } } // NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env. @@ -577,7 +602,7 @@ func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error { sqliteStoreAccounts := len(store.GetAllAccounts(ctx)) if fsStoreAccounts != sqliteStoreAccounts { - return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d", + return fmt.Errorf("failed to migratePreAuto accounts from file to sqlite. Expected accounts: %d, got: %d", fsStoreAccounts, sqliteStoreAccounts) } diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 41b8fa2f7..a21783857 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -1,4 +1,5 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `account_onboardings` (`account_id` text, `created_at` datetime,`updated_at` datetime, `onboarding_flow_pending` numeric, `signup_form_pending` numeric, PRIMARY KEY (`account_id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -38,7 +39,8 @@ CREATE INDEX `idx_networks_id` ON `networks`(`id`); CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO accounts VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','90d6-0242ac120003-edafee4e-63fb-11ec','2024-10-02 16:01:38.210000+02:00','test2.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO account_onboardings VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','2024-10-02 16:01:38.210000+02:00','2021-08-19 20:46:20.005936822+02:00',1,0);INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0); INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); @@ -52,4 +54,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0); INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1'); INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network'); -INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','192.168.0.0','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0); +INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0); diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 5990a0625..f2ef56a23 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62 INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/types/account.go b/management/server/types/account.go index 5a62ee4c6..f0887be07 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -82,11 +82,11 @@ type Account struct { DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` // Settings is a dictionary of Account settings - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` - + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` + Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` } // Subclass used in gorm to only load network and not whole account @@ -104,6 +104,20 @@ type AccountSettings struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +type AccountOnboarding struct { + AccountID string `gorm:"primaryKey"` + OnboardingFlowPending bool + SignupFormPending bool + CreatedAt time.Time + UpdatedAt time.Time +} + +// IsEqual compares two AccountOnboarding objects and returns true if they are equal +func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool { + return o.OnboardingFlowPending == onboarding.OnboardingFlowPending && + o.SignupFormPending == onboarding.SignupFormPending +} + // GetRoutesToSync returns the enabled routes for the peer ID and the routes // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. @@ -866,6 +880,7 @@ func (a *Account) Copy() *Account { Networks: nets, NetworkRouters: networkRouters, NetworkResources: networkResources, + Onboarding: a.Onboarding, } } diff --git a/management/server/types/network.go b/management/server/types/network.go index 00082bb41..eb8415264 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -1,6 +1,7 @@ package types import ( + "encoding/binary" "math/rand" "net" "sync" @@ -161,24 +162,65 @@ func (n *Network) Copy() *Network { // This method considers already taken IPs and reuses IPs if there are gaps in takenIps // E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3 func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { - takenIPMap := make(map[string]struct{}) - takenIPMap[ipNet.IP.String()] = struct{}{} + baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask)) + totalIPs := uint32(1 << SubnetSize) + + taken := make(map[uint32]struct{}, len(takenIps)+1) + taken[baseIP] = struct{}{} // reserve network IP + taken[baseIP+totalIPs-1] = struct{}{} // reserve broadcast IP + for _, ip := range takenIps { - takenIPMap[ip.String()] = struct{}{} + taken[ipToUint32(ip)] = struct{}{} } - ips, _ := generateIPs(&ipNet, takenIPMap) + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + maxAttempts := (int(totalIPs) - len(taken)) / 100 - if len(ips) == 0 { - return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) + for i := 0; i < maxAttempts; i++ { + offset := uint32(rng.Intn(int(totalIPs-2))) + 1 + candidate := baseIP + offset + if _, exists := taken[candidate]; !exists { + return uint32ToIP(candidate), nil + } } - // pick a random IP - s := rand.NewSource(time.Now().Unix()) - r := rand.New(s) - intn := r.Intn(len(ips)) + for offset := uint32(1); offset < totalIPs-1; offset++ { + candidate := baseIP + offset + if _, exists := taken[candidate]; !exists { + return uint32ToIP(candidate), nil + } + } - return ips[intn], nil + return nil, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", ipNet.String()) +} + +func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) { + baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask)) + + ones, bits := ipNet.Mask.Size() + hostBits := bits - ones + + totalIPs := uint32(1 << hostBits) + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + offset := uint32(rng.Intn(int(totalIPs-2))) + 1 + + candidate := baseIP + offset + return uint32ToIP(candidate), nil +} + +func ipToUint32(ip net.IP) uint32 { + ip = ip.To4() + if len(ip) < 4 { + return 0 + } + return binary.BigEndian.Uint32(ip) +} + +func uint32ToIP(n uint32) net.IP { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, n) + return ip } // generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list diff --git a/management/server/user.go b/management/server/user.go index a1f1c46d5..7d8382978 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -550,7 +550,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { - return fmt.Errorf("failed to process user update: %w", err) + return fmt.Errorf("failed to process update for user %s: %w", update.Id, err) } usersToSave = append(usersToSave, updatedUser) addUserEvents = append(addUserEvents, userEvents...)