fix tests

This commit is contained in:
pascal
2026-02-24 10:32:12 +01:00
parent 33cda4d10c
commit b450fa2cca
9 changed files with 170 additions and 93 deletions

View File

@@ -165,7 +165,7 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ServiceManager())
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
})
return proxyService

View File

@@ -157,7 +157,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
// Eagerly create the gRPC server so that all AfterInit hooks are registered
// before we iterate them. Lazy creation after the loop would miss hooks
// registered during GRPCServer() construction (e.g., SetProxyManager).
// registered during GRPCServer() construction (e.g., SetServiceManager).
s.GRPCServer()
for _, fn := range s.afterInit {

View File

@@ -168,7 +168,7 @@ func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel()
}
func (s *ProxyServiceServer) SetProxyManager(manager rpservice.Manager) {
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.serviceManager = manager
}

View File

@@ -8,12 +8,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/types"
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service
proxiesByAccount map[string][]*service.Service
err error
}
@@ -21,31 +21,31 @@ func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, account
return nil
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return []*reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return []*service.Service{}, nil
}
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
@@ -56,7 +56,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac
return nil
}
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
return nil
}
@@ -68,8 +68,8 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID,
return nil
}
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
@@ -97,7 +97,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name string
domain string
userID string
proxiesByAccount map[string][]*reverseproxy.Service
proxiesByAccount map[string][]*service.Service
users map[string]*types.User
proxyErr error
userErr error
@@ -108,7 +108,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
@@ -119,7 +119,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{},
proxiesByAccount: map[string][]*service.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
@@ -130,7 +130,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
@@ -143,8 +143,8 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
@@ -155,12 +155,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{Enabled: false},
},
}},
},
@@ -173,12 +173,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{},
},
@@ -194,12 +194,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
@@ -216,12 +216,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
@@ -237,12 +237,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
@@ -270,10 +270,10 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
{Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"},
},
},
@@ -314,7 +314,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name string
accountID string
domain string
proxiesByAccount map[string][]*reverseproxy.Service
proxiesByAccount map[string][]*service.Service
err error
expectProxy bool
expectErr bool
@@ -323,7 +323,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
@@ -336,7 +336,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
@@ -346,7 +346,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{},
proxiesByAccount: map[string][]*service.Service{},
expectProxy: false,
expectErr: true,
},

View File

@@ -5,7 +5,6 @@ import (
"crypto/rand"
"encoding/base64"
"strings"
"sync"
"testing"
"time"
@@ -26,8 +25,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
}
s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
_ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
return ch
}
@@ -68,7 +66,7 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
},
}
s.SendServiceUpdateToCluster(context.Background(), update, cluster)
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
tokens := make([]string, numProxies)
for i, ch := range channels {
@@ -116,7 +114,7 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
Domain: "test.example.com",
}
s.SendServiceUpdateToCluster(context.Background(), update, cluster)
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
@@ -126,8 +124,8 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
require.Len(t, resp2.Mapping, 1)
// Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken)
assert.Empty(t, msg2.AuthToken)
assert.Empty(t, resp1.Mapping[0].AuthToken)
assert.Empty(t, resp2.Mapping[0].AuthToken)
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {

View File

@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -34,14 +34,15 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore}
serviceManager := &testValidateSessionServiceManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore}
proxyManager := &testValidateSessionProxyManager{}
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager)
proxyService.SetProxyManager(proxyManager)
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)
@@ -57,7 +58,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{
testProxy := &service.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
@@ -65,15 +66,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
},
},
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
restrictedProxy := &service.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
@@ -81,8 +82,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
@@ -199,7 +200,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) {
require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
assert.Equal(t, "service_not_found", resp.DeniedReason)
}
func TestValidateSession_InvalidToken(t *testing.T) {
@@ -242,62 +243,88 @@ func TestValidateSession_MissingToken(t *testing.T) {
assert.Contains(t, resp.DeniedReason, "missing")
}
type testValidateSessionProxyManager struct {
type testValidateSessionServiceManager struct {
store store.Store
}
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
type testValidateSessionUsersManager struct {
store store.Store
}

View File

@@ -211,7 +211,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
usersManager,
)
proxyService.SetProxyManager(&testServiceManager{store: testStore})
proxyService.SetServiceManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService, nil)

View File

@@ -100,7 +100,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(serviceManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked

View File

@@ -18,8 +18,8 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -37,7 +37,7 @@ type integrationTestSetup struct {
grpcServer *grpc.Server
grpcAddr string
cleanup func()
services []*reverseproxy.Service
services []*service.Service
}
func setupIntegrationTest(t *testing.T) *integrationTestSetup {
@@ -66,13 +66,13 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
privKey := base64.StdEncoding.EncodeToString(priv)
// Create test services in the store
services := []*reverseproxy.Service{
services := []*service.Service{
{
ID: "rp-1",
AccountID: "test-account-1",
Name: "Test App 1",
Domain: "app1.test.proxy.io",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "10.0.0.1",
Port: 8080,
@@ -91,7 +91,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
AccountID: "test-account-1",
Name: "Test App 2",
Domain: "app2.test.proxy.io",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "10.0.0.2",
Port: 8080,
@@ -125,17 +125,23 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
HMACKey: []byte("test-hmac-key"),
}
proxyManager := &testProxyManager{}
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
oidcConfig,
nil,
usersManager,
proxyManager,
)
// Use store-backed service manager
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
proxyService.SetProxyManager(svcMgr)
proxyService.SetServiceManager(svcMgr)
proxyController := &testProxyController{}
proxyService.SetProxyController(proxyController)
// Start real gRPC server
lis, err := net.Listen("tcp", "127.0.0.1:0")
@@ -186,6 +192,52 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
return nil, 0, nil
}
// testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
type testProxyController struct{}
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
// noop
}
func (c *testProxyController) GetOIDCValidationConfig() service.OIDCValidationConfig {
return service.OIDCValidationConfig{}
}
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, _, _ string) error {
return nil
}
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, _, _ string) error {
return nil
}
func (c *testProxyController) GetProxiesForCluster(_ string) []string {
return nil
}
// storeBackedServiceManager reads directly from the real store.
type storeBackedServiceManager struct {
store store.Store
@@ -196,19 +248,19 @@ func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accou
return nil
}
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
}
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, errors.New("not implemented")
}
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, errors.New("not implemented")
}
@@ -220,7 +272,7 @@ func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context,
return nil
}
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return nil
}
@@ -232,15 +284,15 @@ func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID
return nil
}
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1")
}
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
}
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}