fix tests

This commit is contained in:
pascal
2026-02-24 10:32:12 +01:00
parent 33cda4d10c
commit b450fa2cca
9 changed files with 170 additions and 93 deletions

View File

@@ -165,7 +165,7 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ServiceManager()) proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController()) proxyService.SetProxyController(s.ServiceProxyController())
}) })
return proxyService return proxyService

View File

@@ -157,7 +157,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
// Eagerly create the gRPC server so that all AfterInit hooks are registered // Eagerly create the gRPC server so that all AfterInit hooks are registered
// before we iterate them. Lazy creation after the loop would miss hooks // before we iterate them. Lazy creation after the loop would miss hooks
// registered during GRPCServer() construction (e.g., SetProxyManager). // registered during GRPCServer() construction (e.g., SetServiceManager).
s.GRPCServer() s.GRPCServer()
for _, fn := range s.afterInit { for _, fn := range s.afterInit {

View File

@@ -168,7 +168,7 @@ func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel() s.pkceCleanupCancel()
} }
func (s *ProxyServiceServer) SetProxyManager(manager rpservice.Manager) { func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.serviceManager = manager s.serviceManager = manager
} }

View File

@@ -8,12 +8,12 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
type mockReverseProxyManager struct { type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service proxiesByAccount map[string][]*service.Service
err error err error
} }
@@ -21,31 +21,31 @@ func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, account
return nil return nil
} }
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
if m.err != nil { if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.proxiesByAccount[accountID], nil return m.proxiesByAccount[accountID], nil
} }
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return []*reverseproxy.Service{}, nil return []*service.Service{}, nil
} }
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error { func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
@@ -56,7 +56,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac
return nil return nil
} }
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error { func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
return nil return nil
} }
@@ -68,8 +68,8 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID,
return nil return nil
} }
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
@@ -97,7 +97,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name string name string
domain string domain string
userID string userID string
proxiesByAccount map[string][]*reverseproxy.Service proxiesByAccount map[string][]*service.Service
users map[string]*types.User users map[string]*types.User
proxyErr error proxyErr error
userErr error userErr error
@@ -108,7 +108,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not found", name: "user not found",
domain: "app.example.com", domain: "app.example.com",
userID: "unknown-user", userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}}, "account1": {{Domain: "app.example.com", AccountID: "account1"}},
}, },
users: map[string]*types.User{}, users: map[string]*types.User{},
@@ -119,7 +119,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy not found in user's account", name: "proxy not found in user's account",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{}, proxiesByAccount: map[string][]*service.Service{},
users: map[string]*types.User{ users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"}, "user1": {Id: "user1", AccountID: "account1"},
}, },
@@ -130,7 +130,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy exists in different account - not accessible", name: "proxy exists in different account - not accessible",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}}, "account2": {{Domain: "app.example.com", AccountID: "account2"}},
}, },
users: map[string]*types.User{ users: map[string]*types.User{
@@ -143,8 +143,8 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "no bearer auth configured - same account allows access", name: "no bearer auth configured - same account allows access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}}, "account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
}, },
users: map[string]*types.User{ users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"}, "user1": {Id: "user1", AccountID: "account1"},
@@ -155,12 +155,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth disabled - same account allows access", name: "bearer auth disabled - same account allows access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false}, BearerAuth: &service.BearerAuthConfig{Enabled: false},
}, },
}}, }},
}, },
@@ -173,12 +173,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth enabled but no groups configured - same account allows access", name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{}, DistributionGroups: []string{},
}, },
@@ -194,12 +194,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not in allowed groups", name: "user not in allowed groups",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"group1", "group2"}, DistributionGroups: []string{"group1", "group2"},
}, },
@@ -216,12 +216,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in one of the allowed groups - allow access", name: "user in one of the allowed groups - allow access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"group1", "group2"}, DistributionGroups: []string{"group1", "group2"},
}, },
@@ -237,12 +237,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in all allowed groups - allow access", name: "user in all allowed groups - allow access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"group1", "group2"}, DistributionGroups: []string{"group1", "group2"},
}, },
@@ -270,10 +270,10 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "multiple proxies in account - finds correct one", name: "multiple proxies in account - finds correct one",
domain: "app2.example.com", domain: "app2.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": { "account1": {
{Domain: "app1.example.com", AccountID: "account1"}, {Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}, {Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"}, {Domain: "app3.example.com", AccountID: "account1"},
}, },
}, },
@@ -314,7 +314,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name string name string
accountID string accountID string
domain string domain string
proxiesByAccount map[string][]*reverseproxy.Service proxiesByAccount map[string][]*service.Service
err error err error
expectProxy bool expectProxy bool
expectErr bool expectErr bool
@@ -323,7 +323,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy found", name: "proxy found",
accountID: "account1", accountID: "account1",
domain: "app.example.com", domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": { "account1": {
{Domain: "other.example.com", AccountID: "account1"}, {Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"}, {Domain: "app.example.com", AccountID: "account1"},
@@ -336,7 +336,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy not found in account", name: "proxy not found in account",
accountID: "account1", accountID: "account1",
domain: "unknown.example.com", domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}}, "account1": {{Domain: "app.example.com", AccountID: "account1"}},
}, },
expectProxy: false, expectProxy: false,
@@ -346,7 +346,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "empty proxy list for account", name: "empty proxy list for account",
accountID: "account1", accountID: "account1",
domain: "app.example.com", domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{}, proxiesByAccount: map[string][]*service.Service{},
expectProxy: false, expectProxy: false,
expectErr: true, expectErr: true,
}, },

View File

@@ -5,7 +5,6 @@ import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@@ -26,8 +25,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
} }
s.connectedProxies.Store(proxyID, conn) s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) _ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
proxySet.(*sync.Map).Store(proxyID, struct{}{})
return ch return ch
} }
@@ -68,7 +66,7 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
}, },
} }
s.SendServiceUpdateToCluster(context.Background(), update, cluster) s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
tokens := make([]string, numProxies) tokens := make([]string, numProxies)
for i, ch := range channels { for i, ch := range channels {
@@ -116,7 +114,7 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
Domain: "test.example.com", Domain: "test.example.com",
} }
s.SendServiceUpdateToCluster(context.Background(), update, cluster) s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
resp1 := drainChannel(ch1) resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2) resp2 := drainChannel(ch2)
@@ -126,8 +124,8 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
require.Len(t, resp2.Mapping, 1) require.Len(t, resp2.Mapping, 1)
// Delete operations should not generate tokens // Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken) assert.Empty(t, resp1.Mapping[0].AuthToken)
assert.Empty(t, msg2.AuthToken) assert.Empty(t, resp2.Mapping[0].AuthToken)
} }
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {

View File

@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -34,14 +34,15 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir()) testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err) require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore} serviceManager := &testValidateSessionServiceManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore} usersManager := &testValidateSessionUsersManager{store: testStore}
proxyManager := &testValidateSessionProxyManager{}
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager) proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetProxyManager(proxyManager) proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore) createTestProxies(t, ctx, testStore)
@@ -57,7 +58,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
pubKey, privKey := generateSessionKeyPair(t) pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{ testProxy := &service.Service{
ID: "testProxyId", ID: "testProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Test Proxy", Name: "Test Proxy",
@@ -65,15 +66,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true, Enabled: true,
SessionPrivateKey: privKey, SessionPrivateKey: privKey,
SessionPublicKey: pubKey, SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
}, },
}, },
} }
require.NoError(t, testStore.CreateService(ctx, testProxy)) require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{ restrictedProxy := &service.Service{
ID: "restrictedProxyId", ID: "restrictedProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Restricted Proxy", Name: "Restricted Proxy",
@@ -81,8 +82,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true, Enabled: true,
SessionPrivateKey: privKey, SessionPrivateKey: privKey,
SessionPublicKey: pubKey, SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"allowedGroupId"}, DistributionGroups: []string{"allowedGroupId"},
}, },
@@ -199,7 +200,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied") assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "proxy_not_found", resp.DeniedReason) assert.Equal(t, "service_not_found", resp.DeniedReason)
} }
func TestValidateSession_InvalidToken(t *testing.T) { func TestValidateSession_InvalidToken(t *testing.T) {
@@ -242,62 +243,88 @@ func TestValidateSession_MissingToken(t *testing.T) {
assert.Contains(t, resp.DeniedReason, "missing") assert.Contains(t, resp.DeniedReason, "missing")
} }
type testValidateSessionProxyManager struct { type testValidateSessionServiceManager struct {
store store.Store store store.Store
} }
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error { func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error { func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error { func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone) return m.store.GetServices(ctx, store.LockingStrengthNone)
} }
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID) return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
} }
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
} }
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil return "", nil
} }
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
type testValidateSessionUsersManager struct { type testValidateSessionUsersManager struct {
store store.Store store store.Store
} }

View File

@@ -211,7 +211,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
usersManager, usersManager,
) )
proxyService.SetProxyManager(&testServiceManager{store: testStore}) proxyService.SetServiceManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService, nil) handler := NewAuthCallbackHandler(proxyService, nil)

View File

@@ -100,7 +100,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager) domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager) serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(serviceManager) proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked // @note this is required so that PAT's validate from store, but JWT's are mocked

View File

@@ -18,8 +18,8 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -37,7 +37,7 @@ type integrationTestSetup struct {
grpcServer *grpc.Server grpcServer *grpc.Server
grpcAddr string grpcAddr string
cleanup func() cleanup func()
services []*reverseproxy.Service services []*service.Service
} }
func setupIntegrationTest(t *testing.T) *integrationTestSetup { func setupIntegrationTest(t *testing.T) *integrationTestSetup {
@@ -66,13 +66,13 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
privKey := base64.StdEncoding.EncodeToString(priv) privKey := base64.StdEncoding.EncodeToString(priv)
// Create test services in the store // Create test services in the store
services := []*reverseproxy.Service{ services := []*service.Service{
{ {
ID: "rp-1", ID: "rp-1",
AccountID: "test-account-1", AccountID: "test-account-1",
Name: "Test App 1", Name: "Test App 1",
Domain: "app1.test.proxy.io", Domain: "app1.test.proxy.io",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "10.0.0.1", Host: "10.0.0.1",
Port: 8080, Port: 8080,
@@ -91,7 +91,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
AccountID: "test-account-1", AccountID: "test-account-1",
Name: "Test App 2", Name: "Test App 2",
Domain: "app2.test.proxy.io", Domain: "app2.test.proxy.io",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "10.0.0.2", Host: "10.0.0.2",
Port: 8080, Port: 8080,
@@ -125,17 +125,23 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
HMACKey: []byte("test-hmac-key"), HMACKey: []byte("test-hmac-key"),
} }
proxyManager := &testProxyManager{}
proxyService := nbgrpc.NewProxyServiceServer( proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{}, &testAccessLogManager{},
tokenStore, tokenStore,
oidcConfig, oidcConfig,
nil, nil,
usersManager, usersManager,
proxyManager,
) )
// Use store-backed service manager // Use store-backed service manager
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore} svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
proxyService.SetProxyManager(svcMgr) proxyService.SetServiceManager(svcMgr)
proxyController := &testProxyController{}
proxyService.SetProxyController(proxyController)
// Start real gRPC server // Start real gRPC server
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")
@@ -186,6 +192,52 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
return nil, 0, nil return nil, 0, nil
} }
// testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
type testProxyController struct{}
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
// noop
}
func (c *testProxyController) GetOIDCValidationConfig() service.OIDCValidationConfig {
return service.OIDCValidationConfig{}
}
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, _, _ string) error {
return nil
}
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, _, _ string) error {
return nil
}
func (c *testProxyController) GetProxiesForCluster(_ string) []string {
return nil
}
// storeBackedServiceManager reads directly from the real store. // storeBackedServiceManager reads directly from the real store.
type storeBackedServiceManager struct { type storeBackedServiceManager struct {
store store.Store store store.Store
@@ -196,19 +248,19 @@ func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accou
return nil return nil
} }
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
} }
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
} }
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
@@ -220,7 +272,7 @@ func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context,
return nil return nil
} }
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return nil return nil
} }
@@ -232,15 +284,15 @@ func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID
return nil return nil
} }
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1") return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1")
} }
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
} }
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
} }