diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 022ea774c..0f202da0b 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -394,23 +394,26 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str return nil } -func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { - if isRequiresApproval { - network, err := c.repo.GetAccountNetwork(ctx, accountID) - if err != nil { - return nil, nil, nil, 0, err - } +func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + network, err := c.repo.GetAccountNetwork(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err + } + if isRequiresApproval { emptyMap := &types.NetworkMap{ Network: network.Copy(), } return peer, emptyMap, nil, 0, nil } - var ( - account *types.Account - err error - ) + if clientSerial > 0 && clientSerial == network.CurrentSerial() { + log.WithContext(ctx).Debugf("client serial %d matches current serial, skipping network map calculation", clientSerial) + return peer, nil, nil, 0, nil + } + + var account *types.Account + if c.experimentalNetworkMap(accountID) { account = c.getAccountFromHolderOrInit(accountID) } else { diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index b1de7d017..effb95075 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -24,7 +24,7 @@ type Controller interface { UpdateAccountPeers(ctx context.Context, accountID string) error UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error BufferUpdateAccountPeers(ctx context.Context, accountID string) error - GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) GetDNSDomain(settings *types.Settings) string StartWarmup(context.Context) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index 5a98eefa8..c6a313523 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -113,9 +113,9 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal } // GetValidatedPeerWithMap mocks base method. -func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer, clientSerial uint64) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p) + ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p, clientSerial) ret0, _ := ret[0].(*peer.Peer) ret1, _ := ret[1].(*types.NetworkMap) ret2, _ := ret[2].([]*posture.Checks) @@ -125,9 +125,9 @@ func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequires } // GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap. -func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call { +func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p, clientSerial any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p, clientSerial) } // OnPeerConnected mocks base method. diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 2b15fe4b8..ded890125 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -104,6 +104,20 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } } +// ToSkipSyncResponse creates a minimal SyncResponse when the client already has the latest network map. +func ToSkipSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, checks []*posture.Checks, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse { + response := &proto.SyncResponse{ + SkipNetworkMapUpdate: true, + Checks: toProtocolChecks(ctx, checks), + } + + nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) + extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) + response.NetbirdConfig = extendedConfig + + return response +} + func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 462e2e6eb..ef2895f91 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -239,7 +239,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S metahash := metaHash(peerMeta, realIP.String()) s.loginFilter.addLogin(peerKey.String(), metahash) - peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncReq.GetNetworkMapSerial()) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) s.syncSem.Add(-1) @@ -702,7 +702,12 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) + var plainResp *proto.SyncResponse + if networkMap == nil { + plainResp = ToSkipSyncResponse(ctx, s.config, peer, turnToken, relayToken, postureChecks, settings.Extra, peerGroups) + } else { + plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) + } key, err := s.secretsManager.GetWGKey() if err != nil { diff --git a/management/server/account.go b/management/server/account.go index a9becc4b6..e69cf46e1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1617,8 +1617,8 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { - peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, NetworkMapSerial: clientSerial}, accountID) if err != nil { return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index b5921ec7a..c1107fed4 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -107,7 +107,7 @@ type Manager interface { UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f125e3a0..53767d506 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3144,7 +3144,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, 0) assert.NoError(b, err) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 928098dbe..55f184b38 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -37,7 +37,7 @@ type MockAccountManager struct { ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) @@ -177,9 +177,9 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncAndMarkPeerFunc != nil { - return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) + return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, clientSerial) } return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index 49f5bf2a5..16bdc2871 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -645,7 +645,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) } - p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer, 0) return p, nmap, pc, err } @@ -731,7 +731,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } } - return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) + return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer, sync.NetworkMapSerial) } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { @@ -859,7 +859,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer } } - p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer, 0) return p, nmap, pc, err } diff --git a/management/server/types/peer.go b/management/server/types/peer.go index 15d343793..4b9769dfc 100644 --- a/management/server/types/peer.go +++ b/management/server/types/peer.go @@ -15,6 +15,9 @@ type PeerSync struct { // UpdateAccountPeers indicate updating account peers, // which occurs when the peer's metadata is updated UpdateAccountPeers bool + // NetworkMapSerial is the last known network map serial number on the client. + // Used to skip network map recalculation if client already has the latest. + NetworkMapSerial uint64 } // PeerLogin used as a data object between the gRPC API and Manager on Login request.