diff --git a/management/server/account.go b/management/server/account.go index 5dd3c8b26..56787a31c 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -84,7 +84,7 @@ type AccountManager interface { UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) GetNetworkMap(peerID string) (*NetworkMap, error) GetPeerNetwork(peerID string) (*Network, error) - AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, error) + AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) @@ -119,10 +119,9 @@ type AccountManager interface { GetDNSSettings(accountID string, userID string) (*DNSSettings, error) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) - GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) - LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API - SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API + LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -134,7 +133,7 @@ type AccountManager interface { UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error GroupValidation(accountId string, groups []string) (bool, error) GetValidatedPeers(account *Account) (map[string]struct{}, error) - SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) + SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) CancelPeerRoutines(peer *nbpeer.Peer) error SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) @@ -1856,13 +1855,13 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } } -func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) { +func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - return nil, nil, status.Errorf(status.Unauthenticated, "peer not registered") + return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered") } - return nil, nil, err + return nil, nil, nil, err } unlock := am.Store.AcquireAccountReadLock(accountID) @@ -1870,12 +1869,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer. account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) + peer, netMap, postureChecks, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } err = am.MarkPeerConnected(peerPubKey, true, realIP, account) @@ -1883,7 +1882,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer. log.Warnf("failed marking peer as connected %s %v", peerPubKey, err) } - return peer, netMap, nil + return peer, netMap, postureChecks, nil } func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { @@ -1926,7 +1925,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.Pee return err } - _, _, err = am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account) + _, _, _, err = am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account) if err != nil { return mapError(err) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 476a4f823..eaadb5633 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -85,7 +85,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac setupKey = key.Key } - _, _, err := manager.AddPeer(setupKey, userID, peer) + _, _, _, err := manager.AddPeer(setupKey, userID, peer) if err != nil { t.Error("expected to add new peer successfully after creating new account, but failed", err) } @@ -997,7 +997,7 @@ func TestAccountManager_AddPeer(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedSetupKey := setupKey.Key - peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) @@ -1065,7 +1065,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedUserID := userID - peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) @@ -1140,7 +1140,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } expectedPeerKey := key.PublicKey().String() - peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) @@ -1315,7 +1315,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { peerKey := key.PublicKey().String() - peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: peerKey, Meta: nbpeer.PeerSystemMeta{Hostname: peerKey}, }) @@ -1662,7 +1662,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, @@ -1715,7 +1715,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, @@ -1759,7 +1759,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, diff --git a/management/server/dns_test.go b/management/server/dns_test.go index b5074e50c..a53789526 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -256,11 +256,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return nil, err } - savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1) + savedPeer1, _, _, err := am.AddPeer("", dnsAdminUserID, peer1) if err != nil { return nil, err } - _, _, err = am.AddPeer("", dnsAdminUserID, peer2) + _, _, _, err = am.AddPeer("", dnsAdminUserID, peer2) if err != nil { return nil, err } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 5501c1925..bf0c3009a 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -139,12 +139,12 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP) + peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP) if err != nil { return mapError(err) } - err = s.sendInitialSync(peerKey, peer, netMap, srv) + err = s.sendInitialSync(peerKey, peer, netMap, postureChecks, srv) if err != nil { log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) return err @@ -376,7 +376,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p sshKey = loginReq.GetPeerKeys().GetSshPubKey() } - peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{ + peer, netMap, postureChecks, err := s.accountManager.LoginPeer(PeerLogin{ WireGuardPubKey: peerKey.String(), SSHKey: string(sshKey), Meta: extractPeerMeta(loginReq.GetMeta()), @@ -398,7 +398,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p loginResp := &proto.LoginResponse{ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), - Checks: toProtocolChecks(s.accountManager, peerKey.String()), + Checks: toProtocolChecks(postureChecks), } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { @@ -520,7 +520,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee return remotePeers } -func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse { +func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse { wtConfig := toWiretrusteeConfig(config, turnCredentials) pConfig := toPeerConfig(peer, networkMap.Network, dnsName) @@ -551,7 +551,7 @@ func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer. FirewallRules: firewallRules, FirewallRulesIsEmpty: len(firewallRules) == 0, }, - Checks: toProtocolChecks(accountManager, peer.Key), + Checks: toProtocolChecks(checks), } } @@ -561,7 +561,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { // make secret time based TURN credentials optional var turnCredentials *TURNCredentials if s.config.TURNConfig.TimeBasedCredentials { @@ -570,7 +570,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net } else { turnCredentials = nil } - plainResp := toSyncResponse(s.accountManager, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain()) + plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { @@ -715,15 +715,9 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) return &proto.Empty{}, nil } -// toProtocolChecks returns posture checks for the peer that needs to be evaluated on the client side. -func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Checks { - postureChecks, err := accountManager.GetPeerAppliedPostureChecks(peerKey) - if err != nil { - log.Errorf("failed getting peer's: %s posture checks: %v", peerKey, err) - return nil - } - - protoChecks := make([]*proto.Checks, 0) +// toProtocolChecks converts posture checks to protocol checks. +func toProtocolChecks(postureChecks []*posture.Checks) []*proto.Checks { + protoChecks := make([]*proto.Checks, 0, len(postureChecks)) for _, postureCheck := range postureChecks { protoChecks = append(protoChecks, toProtocolCheck(postureCheck)) } @@ -732,7 +726,7 @@ func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Ch } // toProtocolCheck converts a posture.Checks to a proto.Checks. -func toProtocolCheck(postureCheck posture.Checks) *proto.Checks { +func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks { protoCheck := &proto.Checks{} if check := postureCheck.Checks.ProcessCheck; check != nil { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a915dca64..669fab861 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -30,11 +30,11 @@ type MockAccountManager struct { ListUsersFunc func(accountID string) ([]*server.User, error) GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) + SyncAndMarkPeerFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) DeletePeerFunc func(accountID, peerKey, userID string) error GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error) - AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) + AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) @@ -83,10 +83,9 @@ type MockAccountManager struct { GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) - GetPeerAppliedPostureChecksFunc func(peerKey string) ([]posture.Checks, error) UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) - LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) - SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) + LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -102,11 +101,11 @@ type MockAccountManager struct { FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) } -func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) { +func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(peerPubKey, meta, realIP) } - return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { @@ -282,11 +281,11 @@ func (am *MockAccountManager) AddPeer( setupKey string, userId string, peer *nbpeer.Peer, -) (*nbpeer.Peer, *server.NetworkMap, error) { +) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { return am.AddPeerFunc(setupKey, userId, peer) } - return nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface @@ -627,14 +626,6 @@ func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented") } -// GetPeerAppliedPostureChecks mocks GetPeerAppliedPostureChecks of the AccountManager interface -func (am *MockAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) { - if am.GetPeerAppliedPostureChecksFunc != nil { - return am.GetPeerAppliedPostureChecksFunc(peerKey) - } - return nil, status.Errorf(codes.Unimplemented, "method GetPeerAppliedPostureChecks is not implemented") -} - // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) { if am.UpdateAccountSettingsFunc != nil { @@ -644,19 +635,19 @@ func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, ne } // LoginPeer mocks LoginPeer of the AccountManager interface -func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) { +func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.LoginPeerFunc != nil { return am.LoginPeerFunc(login) } - return nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented") } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) { +func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(sync, account) } - return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } // GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index f2921532d..4e07943b3 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -851,11 +851,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error return nil, err } - _, _, err = am.AddPeer("", userID, peer1) + _, _, _, err = am.AddPeer("", userID, peer1) if err != nil { return nil, err } - _, _, err = am.AddPeer("", userID, peer2) + _, _, _, err = am.AddPeer("", userID, peer2) if err != nil { return nil, err } diff --git a/management/server/peer.go b/management/server/peer.go index 6987cff3e..fa482eec0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/netbirdio/netbird/management/server/posture" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -333,10 +334,10 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error) // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, error) { +func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access - return nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") + return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") } upperKey := strings.ToUpper(setupKey) @@ -350,7 +351,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P accountID, err = am.Store.GetAccountIDBySetupKey(setupKey) } if err != nil { - return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") + return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") } unlock := am.Store.AcquireAccountWriteLock(accountID) @@ -364,7 +365,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) account, err = am.Store.GetAccount(accountID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { @@ -383,7 +384,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P // The connecting peer should be able to recover with a retry. _, err = account.FindPeerByPubKey(peer.Key) if err == nil { - return nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") + return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } opEvent := &activity.Event{ @@ -397,11 +398,11 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P // validate the setup key if adding with a key sk, err := account.FindSetupKey(upperKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if !sk.IsValid() { - return nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") + return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") } account.SetupKeys[sk.Key] = sk.IncrementUsage() @@ -419,14 +420,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels) if err != nil { - return nil, nil, err + return nil, nil, nil, err } peer.DNSLabel = newLabel network := account.Network nextIp, err := AllocatePeerIP(network.Net, takenIps) if err != nil { - return nil, nil, err + return nil, nil, nil, err } registrationTime := time.Now().UTC() @@ -453,7 +454,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P // add peer to 'All' group group, err := account.GetGroupAll() if err != nil { - return nil, nil, err + return nil, nil, nil, err } group.Peers = append(group.Peers, newPeer.ID) @@ -461,12 +462,12 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P if addedByUser { groupsToAdd, err = account.getUserGroups(userID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } } else { groupsToAdd, err = account.getSetupKeyGroups(upperKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } } @@ -483,7 +484,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P if addedByUser { user, err := account.FindUser(userID) if err != nil { - return nil, nil, status.Errorf(status.Internal, "couldn't find user") + return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") } user.updateLastLogin(newPeer.LastLogin) } @@ -492,7 +493,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P account.Network.IncSerial() err = am.Store.SaveAccount(account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // Account is saved, we can release the lock @@ -511,33 +512,35 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } + + postureChecks := am.getPeerPostureChecks(account, peer) networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap) - return newPeer, networkMap, nil + return newPeer, networkMap, postureChecks, nil } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) { +func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) if err != nil { - return nil, nil, status.NewPeerNotRegisteredError() + return nil, nil, nil, status.NewPeerNotRegisteredError() } err = checkIfPeerOwnerIsBlocked(peer, account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if peerLoginExpired(peer, account.Settings) { - return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") + return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } peer, updated := updatePeerMeta(peer, sync.Meta, account) if updated { err = am.Store.SaveAccount(account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if sync.UpdateAccountPeers { @@ -547,14 +550,16 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { - return nil, nil, err + return nil, nil, nil, err } + var postureChecks []*posture.Checks + if peerNotValid { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } - return peer, emptyMap, nil + return peer, emptyMap, postureChecks, nil } if isStatusChanged { @@ -563,14 +568,16 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp validPeersMap, err := am.GetValidatedPeers(account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil + postureChecks = am.getPeerPostureChecks(account, peer) + + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil } // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. -func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) { +func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { @@ -596,18 +603,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw return am.AddPeer(login.SetupKey, login.UserID, newPeer) } + log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) - return nil, nil, status.Errorf(status.Internal, "failed while logging in peer") + return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") } peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey) if err != nil { - return nil, nil, status.NewPeerNotRegisteredError() + return nil, nil, nil, status.NewPeerNotRegisteredError() } accSettings, err := am.Store.GetAccountSettings(accountID) if err != nil { - return nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err) + return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err) } var isWriteLock bool @@ -617,7 +625,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw switch { case expired: if err := checkAuth(login.UserID, peer); err != nil { - return nil, nil, err + return nil, nil, nil, err } isWriteLock = true log.Debugf("peer login expired, acquiring write lock") @@ -647,17 +655,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } peer, err = account.FindPeerByPubKey(login.WireGuardPubKey) if err != nil { - return nil, nil, status.NewPeerNotRegisteredError() + return nil, nil, nil, status.NewPeerNotRegisteredError() } err = checkIfPeerOwnerIsBlocked(peer, account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // this flag prevents unnecessary calls to the persistent store. @@ -666,7 +674,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw if peerLoginExpired(peer, account.Settings) { err = checkAuth(login.UserID, peer) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. @@ -677,7 +685,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw // sync user last login with peer last login user, err := account.FindUser(login.UserID) if err != nil { - return nil, nil, status.Errorf(status.Internal, "couldn't find user") + return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") } user.updateLastLogin(peer.LastLogin) @@ -686,7 +694,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { - return nil, nil, err + return nil, nil, nil, err } peer, updated := updatePeerMeta(peer, login.Meta, account) if updated { @@ -695,17 +703,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw peer, err = am.checkAndUpdatePeerSSHKey(peer, account, login.SSHKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if shouldStoreAccount { if !isWriteLock { log.Errorf("account %s should be stored but is not write locked", accountID) - return nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked") + return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked") } err = am.Store.SaveAccount(account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } } unlock() @@ -715,19 +723,22 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw am.updateAccountPeers(account) } + var postureChecks []*posture.Checks + if isRequiresApproval { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } - return peer, emptyMap, nil + return peer, emptyMap, postureChecks, nil } approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } + postureChecks = am.getPeerPostureChecks(account, peer) - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil } func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { @@ -916,8 +927,10 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) { log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) continue } + + postureChecks := am.getPeerPostureChecks(account, peer) remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) - update := toSyncResponse(am, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) + update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) } } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 6063cc2a7..c5305cf5b 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -92,7 +92,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { return } - peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -106,7 +106,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) return } - _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -165,7 +165,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -179,7 +179,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) return } - peer2, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -341,7 +341,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -355,7 +355,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) return } - _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -413,7 +413,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { return } - peer1, _, err := manager.AddPeer("", someUser, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer("", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -429,7 +429,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // the second peer added with a setup key - peer2, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -601,7 +601,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -610,7 +610,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 873f8da59..d525482b7 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -7,7 +7,6 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) const ( @@ -185,36 +184,14 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh return postureChecks, nil } -// GetPeerAppliedPostureChecks returns posture checks that are applied to the peer. -func (am *DefaultAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) { - account, err := am.Store.GetAccountByPeerPubKey(peerKey) - if err != nil { - log.Errorf("failed while getting peer %s: %v", peerKey, err) - return nil, err - } - - peer, err := account.FindPeerByPubKey(peerKey) - if err != nil { - return nil, status.Errorf(status.NotFound, "peer is not registered") - } - if peer == nil { - return nil, nil - } - - peerPostureChecks := am.collectPeerPostureChecks(account, peer) - - postureChecksList := make([]posture.Checks, 0, len(peerPostureChecks)) - for _, check := range peerPostureChecks { - postureChecksList = append(postureChecksList, check) - } - - return postureChecksList, nil -} - -// collectPeerPostureChecks collects the posture checks applied for a given peer. -func (am *DefaultAccountManager) collectPeerPostureChecks(account *Account, peer *nbpeer.Peer) map[string]posture.Checks { +// getPeerPostureChecks returns the posture checks applied for a given peer. +func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { peerPostureChecks := make(map[string]posture.Checks) + if len(account.PostureChecks) == 0 { + return nil + } + for _, policy := range account.Policies { if !policy.Enabled { continue @@ -225,7 +202,13 @@ func (am *DefaultAccountManager) collectPeerPostureChecks(account *Account, peer } } - return peerPostureChecks + postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) + for _, check := range peerPostureChecks { + checkCopy := check + postureChecksList = append(postureChecksList, &checkCopy) + } + + return postureChecksList } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.