diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go index 2a88fd2b0..67a8eedc0 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go @@ -62,20 +62,22 @@ func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) { assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result) } -func TestGetClusterAllowList_BYOPError_FallbackToShared(t *testing.T) { +func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) { pm := &mockProxyManager{ getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { return nil, errors.New("db error") }, getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { - return []string{"eu.proxy.netbird.io"}, nil + t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails") + return nil, nil }, } mgr := Manager{proxyManager: pm} result, err := mgr.getClusterAllowList(context.Background(), "acc-123") - require.NoError(t, err) - assert.Equal(t, []string{"eu.proxy.netbird.io"}, result) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "BYOP cluster addresses") } func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 96d1142e7..f203ad328 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -1,6 +1,11 @@ package proxy -import "time" +import ( + "errors" + "time" +) + +var ErrAccountProxyAlreadyExists = errors.New("account already has a registered proxy") const ( StatusConnected = "connected" diff --git a/management/internals/modules/reverseproxy/proxytoken/handler.go b/management/internals/modules/reverseproxy/proxytoken/handler.go index 591c465cb..728cdf723 100644 --- a/management/internals/modules/reverseproxy/proxytoken/handler.go +++ b/management/internals/modules/reverseproxy/proxytoken/handler.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" ) type handler struct { @@ -58,8 +59,14 @@ func (h *handler) createToken(w http.ResponseWriter, r *http.Request) { } var expiresIn time.Duration - if req.ExpiresIn != nil && *req.ExpiresIn > 0 { - expiresIn = time.Duration(*req.ExpiresIn) * time.Second + if req.ExpiresIn != nil { + if *req.ExpiresIn < 0 { + util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w) + return + } + if *req.ExpiresIn > 0 { + expiresIn = time.Duration(*req.ExpiresIn) * time.Second + } } accountID := userAuth.AccountId @@ -134,7 +141,11 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) { token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID) if err != nil { - util.WriteErrorResponse("token not found", http.StatusNotFound, w) + if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound { + util.WriteErrorResponse("token not found", http.StatusNotFound, w) + } else { + util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w) + } return } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index f8c674749..73ff40e86 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -183,7 +183,10 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest if token != nil && token.AccountID != nil { accountID = token.AccountID - existingProxy, _ := s.proxyManager.GetAccountProxy(ctx, *accountID) + existingProxy, err := s.proxyManager.GetAccountProxy(ctx, *accountID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get account proxy for %s: %v", *accountID, err) + } if existingProxy != nil && existingProxy.ID != proxyID { if existingProxy.Status == proxy.StatusConnected { return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account") @@ -223,7 +226,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, accountID); err != nil { if accountID != nil { cancel() - if strings.Contains(err.Error(), "UNIQUE constraint") || strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "idx_proxy_account_id_unique") { + if errors.Is(err, proxy.ErrAccountProxyAlreadyExists) { return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account") } return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err) @@ -453,14 +456,18 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA // BYOP proxies only receive updates for their own account's services. func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) { log.Debugf("Broadcasting service update to all connected proxy servers") - var updateAccountID string - if len(update.Mapping) > 0 { - updateAccountID = update.Mapping[0].AccountId + updateAccountIDs := make(map[string]struct{}) + for _, m := range update.Mapping { + if m.AccountId != "" { + updateAccountIDs[m.AccountId] = struct{}{} + } } s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) - if conn.accountID != nil && updateAccountID != "" && *conn.accountID != updateAccountID { - return true + if conn.accountID != nil && len(updateAccountIDs) > 0 { + if _, ok := updateAccountIDs[*conn.accountID]; !ok { + return true + } } msg := s.perProxyMessage(update, conn.proxyID) if msg == nil { diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index b2cb64581..e58ccc928 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -91,6 +91,9 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) { + if m.err != nil { + return nil, m.err + } for _, services := range m.proxiesByAccount { for _, svc := range services { if svc.Domain == domain { diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index b7abb28b6..032a04a19 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -11,8 +11,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -304,6 +307,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { assert.Contains(t, err.Error(), "invalid state format") } +func scopedCtx(accountID string) context.Context { + token := &types.ProxyAccessToken{ + ID: "token-1", + AccountID: &accountID, + } + return context.WithValue(context.Background(), ProxyTokenContextKey, token) +} + +func globalCtx() context.Context { + token := &types.ProxyAccessToken{ + ID: "token-global", + } + return context.WithValue(context.Background(), ProxyTokenContextKey, token) +} + +func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "acc-1") + assert.NoError(t, err) +} + +func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "acc-2") + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.PermissionDenied, st.Code()) +} + +func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "") + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.PermissionDenied, st.Code()) +} + +func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) { + err := enforceAccountScope(globalCtx(), "acc-1") + assert.NoError(t, err) + + err = enforceAccountScope(globalCtx(), "acc-2") + assert.NoError(t, err) + + err = enforceAccountScope(globalCtx(), "") + assert.NoError(t, err) +} + +func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) { + err := enforceAccountScope(context.Background(), "acc-1") + assert.NoError(t, err) +} + func TestValidateState_RejectsInvalidHMAC(t *testing.T) { ctx := context.Background() pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 3523a8aba..ccf01fd5e 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -17,6 +17,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -5411,11 +5412,23 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { result := s.db.WithContext(ctx).Save(p) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error) - return status.Errorf(status.Internal, "failed to save proxy: %v", result.Error) + if isUniqueConstraintError(result.Error) { + return proxy.ErrAccountProxyAlreadyExists + } + return status.Errorf(status.Internal, "failed to save proxy") } return nil } +func isUniqueConstraintError(err error) bool { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + return true + } + errStr := err.Error() + return strings.Contains(errStr, "UNIQUE constraint") || strings.Contains(errStr, "duplicate key") +} + // DisconnectProxy updates only the status, disconnected_at, and last_seen fields func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID string) error { now := time.Now() diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 3aebac60c..c5ef88f07 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3160,6 +3160,7 @@ components: example: "my-proxy-token" expires_in: type: integer + minimum: 0 description: Token expiration in seconds (0 = never expires) example: 0 required: