review comments

This commit is contained in:
crn4
2026-03-24 13:32:38 +01:00
parent 177171e437
commit 0b5380a7dc
8 changed files with 113 additions and 16 deletions

View File

@@ -62,20 +62,22 @@ func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_BYOPError_FallbackToShared(t *testing.T) {
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("db error")
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "BYOP cluster addresses")
}
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {

View File

@@ -1,6 +1,11 @@
package proxy
import "time"
import (
"errors"
"time"
)
var ErrAccountProxyAlreadyExists = errors.New("account already has a registered proxy")
const (
StatusConnected = "connected"

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
@@ -58,8 +59,14 @@ func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
}
var expiresIn time.Duration
if req.ExpiresIn != nil && *req.ExpiresIn > 0 {
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
if req.ExpiresIn != nil {
if *req.ExpiresIn < 0 {
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
return
}
if *req.ExpiresIn > 0 {
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
}
}
accountID := userAuth.AccountId
@@ -134,7 +141,11 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
if err != nil {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
} else {
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
}
return
}

View File

@@ -183,7 +183,10 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
if token != nil && token.AccountID != nil {
accountID = token.AccountID
existingProxy, _ := s.proxyManager.GetAccountProxy(ctx, *accountID)
existingProxy, err := s.proxyManager.GetAccountProxy(ctx, *accountID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get account proxy for %s: %v", *accountID, err)
}
if existingProxy != nil && existingProxy.ID != proxyID {
if existingProxy.Status == proxy.StatusConnected {
return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account")
@@ -223,7 +226,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, accountID); err != nil {
if accountID != nil {
cancel()
if strings.Contains(err.Error(), "UNIQUE constraint") || strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "idx_proxy_account_id_unique") {
if errors.Is(err, proxy.ErrAccountProxyAlreadyExists) {
return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account")
}
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
@@ -453,14 +456,18 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// BYOP proxies only receive updates for their own account's services.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
log.Debugf("Broadcasting service update to all connected proxy servers")
var updateAccountID string
if len(update.Mapping) > 0 {
updateAccountID = update.Mapping[0].AccountId
updateAccountIDs := make(map[string]struct{})
for _, m := range update.Mapping {
if m.AccountId != "" {
updateAccountIDs[m.AccountId] = struct{}{}
}
}
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
if conn.accountID != nil && updateAccountID != "" && *conn.accountID != updateAccountID {
return true
if conn.accountID != nil && len(updateAccountIDs) > 0 {
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
return true
}
}
msg := s.perProxyMessage(update, conn.proxyID)
if msg == nil {

View File

@@ -91,6 +91,9 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
if m.err != nil {
return nil, m.err
}
for _, services := range m.proxiesByAccount {
for _, svc := range services {
if svc.Domain == domain {

View File

@@ -11,8 +11,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -304,6 +307,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
assert.Contains(t, err.Error(), "invalid state format")
}
func scopedCtx(accountID string) context.Context {
token := &types.ProxyAccessToken{
ID: "token-1",
AccountID: &accountID,
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func globalCtx() context.Context {
token := &types.ProxyAccessToken{
ID: "token-global",
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
assert.NoError(t, err)
}
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
err := enforceAccountScope(globalCtx(), "acc-1")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "acc-2")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "")
assert.NoError(t, err)
}
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
err := enforceAccountScope(context.Background(), "acc-1")
assert.NoError(t, err)
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)