diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index ad6b34c5f..0b9326fbc 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -184,8 +184,14 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S realIP := getRealIP(ctx) sRealIP := realIP.String() peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) + userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String()) + if err != nil { + s.syncSem.Add(-1) + return mapError(ctx, err) + } + metahashed := metaHash(peerMeta, sRealIP) - if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { + if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() } diff --git a/management/server/account.go b/management/server/account.go index 405a3c0f6..52dcc567e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2156,3 +2156,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti return nil } + +func (am *DefaultAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) { + return am.Store.GetUserIDByPeerKey(ctx, store.LockingStrengthNone, peerKey) +} diff --git a/management/server/account/manager.go b/management/server/account/manager.go index b5921ec7a..f0b7c3857 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -123,4 +123,5 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) + GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 928098dbe..0d7d2bc3d 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -2,11 +2,12 @@ package mock_server import ( "context" - "github.com/netbirdio/netbird/shared/auth" "net" "net/netip" "time" + "github.com/netbirdio/netbird/shared/auth" + "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -988,3 +989,7 @@ func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, ac } return nil } + +func (am *MockAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) { + return "something", nil +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index d2220d4b4..73565a462 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4082,3 +4082,21 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro return peers, nil } + +func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var userID string + result := tx.Model(&nbpeer.Peer{}). + Select("user_id"). + Take(&userID, GetKeyQueryCondition(s), peerKey) + + if result.Error != nil { + return "", status.Errorf(status.Internal, "failed to get user ID by peer key") + } + + return userID, nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 2e2623910..d63d624de 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3718,6 +3718,69 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { } } +func TestSqlStore_GetUserIDByPeerKey(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + userID := "test-user-123" + peerKey := "peer-key-abc" + + peer := &nbpeer.Peer{ + ID: "test-peer-1", + Key: peerKey, + AccountID: existingAccountID, + UserID: userID, + IP: net.IP{10, 0, 0, 1}, + DNSLabel: "test-peer-1", + } + + err = store.AddPeerToAccount(context.Background(), peer) + require.NoError(t, err) + + retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey) + require.NoError(t, err) + assert.Equal(t, userID, retrievedUserID) +} + +func TestSqlStore_GetUserIDByPeerKey_NotFound(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + nonExistentPeerKey := "non-existent-peer-key" + + userID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, nonExistentPeerKey) + require.Error(t, err) + assert.Equal(t, "", userID) +} + +func TestSqlStore_GetUserIDByPeerKey_NoUserID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerKey := "peer-key-abc" + + peer := &nbpeer.Peer{ + ID: "test-peer-1", + Key: peerKey, + AccountID: existingAccountID, + UserID: "", + IP: net.IP{10, 0, 0, 1}, + DNSLabel: "test-peer-1", + } + + err = store.AddPeerToAccount(context.Background(), peer) + require.NoError(t, err) + + retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey) + require.NoError(t, err) + assert.Equal(t, "", retrievedUserID) +} + func TestSqlStore_ApproveAccountPeers(t *testing.T) { runTestForAllEngines(t, "", func(t *testing.T, store Store) { accountID := "test-account" diff --git a/management/server/store/store.go b/management/server/store/store.go index 0ec7949f9..dbe135406 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -204,6 +204,7 @@ type Store interface { MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error) + GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) } const (