mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-15 05:09:55 +00:00
[proxy] feature: bring your own proxy (#5627)
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -50,6 +51,11 @@ type ProxyOIDCConfig struct {
|
||||
KeysLocation string
|
||||
}
|
||||
|
||||
// ProxyTokenChecker checks whether a proxy access token is still valid.
|
||||
type ProxyTokenChecker interface {
|
||||
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
@@ -78,6 +84,9 @@ type ProxyServiceServer struct {
|
||||
// Store for one-time authentication tokens
|
||||
tokenStore *OneTimeTokenStore
|
||||
|
||||
// Checker for proxy access token validity
|
||||
tokenChecker ProxyTokenChecker
|
||||
|
||||
// OIDC configuration for proxy authentication
|
||||
oidcConfig ProxyOIDCConfig
|
||||
|
||||
@@ -123,6 +132,8 @@ type proxyConnection struct {
|
||||
proxyID string
|
||||
sessionID string
|
||||
address string
|
||||
accountID *string
|
||||
tokenID string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
@@ -130,8 +141,19 @@ type proxyConnection struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token == nil || token.AccountID == nil {
|
||||
return nil
|
||||
}
|
||||
if requestAccountID == "" || *token.AccountID != requestAccountID {
|
||||
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
@@ -141,6 +163,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
tokenChecker: tokenChecker,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
cancel: cancel,
|
||||
}
|
||||
@@ -200,6 +223,25 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
var accountID *string
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||
}
|
||||
}
|
||||
|
||||
var tokenID string
|
||||
if token != nil {
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
@@ -217,6 +259,8 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
accountID: accountID,
|
||||
tokenID: tokenID,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
@@ -224,7 +268,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Register proxy in database with capabilities
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
@@ -233,10 +276,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
@@ -266,6 +312,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"account_id": accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
@@ -286,7 +333,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, proxyRecord)
|
||||
go s.heartbeat(connCtx, conn, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
@@ -298,8 +345,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute and
|
||||
// disconnects the proxy if its access token has been revoked.
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -309,6 +357,19 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
}
|
||||
|
||||
if conn.tokenID != "" && s.tokenChecker != nil {
|
||||
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
|
||||
continue
|
||||
}
|
||||
if !valid {
|
||||
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
|
||||
conn.cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
|
||||
return
|
||||
@@ -316,8 +377,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||
}
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||
// Only entries matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
@@ -355,7 +414,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
var services []*rpservice.Service
|
||||
var err error
|
||||
if conn.accountID != nil {
|
||||
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
|
||||
} else {
|
||||
services, err = s.serviceManager.GetGlobalServices(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
@@ -380,8 +445,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// isProxyAddressValid validates a proxy address
|
||||
// isProxyAddressValid validates a proxy address (domain name or IP address)
|
||||
func isProxyAddressValid(addr string) bool {
|
||||
if addr == "" {
|
||||
return false
|
||||
}
|
||||
if net.ParseIP(addr) != nil {
|
||||
return true
|
||||
}
|
||||
_, err := domain.ValidateDomains([]string{addr})
|
||||
return err == nil
|
||||
}
|
||||
@@ -405,6 +476,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
|
||||
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fields := log.Fields{
|
||||
"service_id": accessLog.GetServiceId(),
|
||||
"account_id": accessLog.GetAccountId(),
|
||||
@@ -442,11 +517,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
|
||||
// Management should call this when services are created/updated/removed.
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
// BYOP proxies only receive updates for their own account's services.
|
||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||
updateAccountIDs := make(map[string]struct{})
|
||||
for _, m := range update.Mapping {
|
||||
if m.AccountId != "" {
|
||||
updateAccountIDs[m.AccountId] = struct{}{}
|
||||
}
|
||||
}
|
||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||
conn := value.(*proxyConnection)
|
||||
resp := s.perProxyMessage(update, conn.proxyID)
|
||||
connUpdate := update
|
||||
if conn.accountID != nil && len(updateAccountIDs) > 0 {
|
||||
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
|
||||
return true
|
||||
}
|
||||
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
|
||||
if len(filtered) == 0 {
|
||||
return true
|
||||
}
|
||||
connUpdate = &proto.GetMappingUpdateResponse{
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
@@ -463,6 +559,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
})
|
||||
}
|
||||
|
||||
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
|
||||
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
conn.cancel()
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
|
||||
}
|
||||
}
|
||||
|
||||
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
|
||||
var filtered []*proto.ProxyMapping
|
||||
for _, m := range mappings {
|
||||
if m.AccountId == accountID {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// GetConnectedProxies returns a list of connected proxy IDs
|
||||
func (s *ProxyServiceServer) GetConnectedProxies() []string {
|
||||
var proxies []string
|
||||
@@ -531,6 +647,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
|
||||
continue
|
||||
}
|
||||
if !proxyAcceptsMapping(conn, update) {
|
||||
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
|
||||
continue
|
||||
@@ -618,6 +737,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||
@@ -737,6 +860,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID := req.GetAccountId()
|
||||
serviceID := req.GetServiceId()
|
||||
protoStatus := req.GetStatus()
|
||||
@@ -807,6 +934,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||
|
||||
// CreateProxyPeer handles proxy peer creation with one-time token authentication
|
||||
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceID := req.GetServiceId()
|
||||
accountID := req.GetAccountId()
|
||||
token := req.GetToken()
|
||||
@@ -861,6 +992,10 @@ func strPtr(s string) *string {
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(req.GetRedirectUrl())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||
@@ -989,21 +1124,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
// Find the service by domain to get its signing key
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
var service *rpservice.Service
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
service = svc
|
||||
break
|
||||
}
|
||||
}
|
||||
if service == nil {
|
||||
return "", fmt.Errorf("service not found for domain: %s", domain)
|
||||
return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
if service.SessionPrivateKey == "" {
|
||||
@@ -1101,6 +1224,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
@@ -1184,18 +1311,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
if service.Domain == domain {
|
||||
return service, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
||||
return s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
|
||||
29
management/internals/shared/grpc/proxy_address_test.go
Normal file
29
management/internals/shared/grpc/proxy_address_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsProxyAddressValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
valid bool
|
||||
}{
|
||||
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
|
||||
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
|
||||
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
|
||||
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
|
||||
{name: "valid IPv6", addr: "::1", valid: true},
|
||||
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
|
||||
{name: "empty string", addr: "", valid: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
||||
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
||||
|
||||
if !token.IsValid() {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
||||
}
|
||||
|
||||
@@ -53,6 +53,10 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -91,6 +95,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
|
||||
|
||||
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
for _, services := range m.proxiesByAccount {
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
return svc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, errors.New("service not found for domain: " + domain)
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -12,9 +12,12 @@ import (
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -316,6 +319,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "invalid state format")
|
||||
}
|
||||
|
||||
func scopedCtx(accountID string) context.Context {
|
||||
token := &types.ProxyAccessToken{
|
||||
ID: "token-1",
|
||||
AccountID: &accountID,
|
||||
}
|
||||
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||
}
|
||||
|
||||
func globalCtx() context.Context {
|
||||
token := &types.ProxyAccessToken{
|
||||
ID: "token-global",
|
||||
}
|
||||
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
|
||||
require.Error(t, err)
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "")
|
||||
require.Error(t, err)
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
|
||||
err := enforceAccountScope(globalCtx(), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = enforceAccountScope(globalCtx(), "acc-2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = enforceAccountScope(globalCtx(), "")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
|
||||
err := enforceAccountScope(context.Background(), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
@@ -318,13 +318,17 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
|
||||
|
||||
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error {
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -340,6 +344,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -348,6 +356,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user