Merge remote-tracking branch 'origin/main' into proto-ipv6-overlay

# Conflicts:
#	client/firewall/uspfilter/forwarder/endpoint.go
#	client/wasm/cmd/main.go
#	proxy/cmd/proxy/cmd/debug.go
This commit is contained in:
Viktor Liu
2026-05-04 11:40:41 +02:00
105 changed files with 6385 additions and 415 deletions

View File

@@ -423,7 +423,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
go am.UpdateAccountPeers(ctx, accountID)
go am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceAccountSettings, Operation: types.UpdateOperationUpdate})
}
return newSettings, nil
@@ -1654,7 +1654,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate})
}
return nil

View File

@@ -125,8 +125,8 @@ type Manager interface {
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string)
BufferUpdateAccountPeers(ctx context.Context, accountID string)
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
GetStore() store.Store

View File

@@ -111,15 +111,15 @@ func (mr *MockManagerMockRecorder) ApproveUser(ctx, accountID, initiatorUserID,
}
// BufferUpdateAccountPeers mocks base method.
func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason)
}
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BuildUserInfosForAccount mocks base method.
@@ -1597,15 +1597,15 @@ func (mr *MockManagerMockRecorder) UpdateAccountOnboarding(ctx, accountID, userI
}
// UpdateAccountPeers mocks base method.
func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string) {
func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason)
}
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAccountSettings mocks base method.

View File

@@ -0,0 +1,8 @@
package account
const (
// PATMinExpireDays is the minimum allowed Personal Access Token expiration in days.
PATMinExpireDays = 1
// PATMaxExpireDays is the maximum allowed Personal Access Token expiration in days.
PATMaxExpireDays = 365
)

View File

@@ -1765,7 +1765,7 @@ func hasNilField(x interface{}) error {
if f := rv.Field(i); f.IsValid() {
k := f.Kind()
switch k {
case reflect.Ptr:
case reflect.Pointer:
if f.IsNil() {
return fmt.Errorf("field %s is nil", f.String())
}

View File

@@ -0,0 +1,61 @@
package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
)
const (
usedTokenKeyPrefix = "jwt-used:"
usedTokenMarker = "1"
)
var (
ErrTokenAlreadyUsed = errors.New("JWT already used")
ErrTokenExpired = errors.New("JWT expired")
)
type SessionStore struct {
cache *cache.Cache[string]
}
func NewSessionStore(cacheStore store.StoreInterface) *SessionStore {
return &SessionStore{cache: cache.New[string](cacheStore)}
}
// RegisterToken records a JWT until its exp time and rejects reuse.
func (s *SessionStore) RegisterToken(ctx context.Context, token string, expiresAt time.Time) error {
ttl := time.Until(expiresAt)
if ttl <= 0 {
return ErrTokenExpired
}
key := usedTokenKeyPrefix + hashToken(token)
_, err := s.cache.Get(ctx, key)
if err == nil {
return ErrTokenAlreadyUsed
}
var notFound *store.NotFound
if !errors.As(err, &notFound) {
return fmt.Errorf("failed to lookup used token entry: %w", err)
}
if err := s.cache.Set(ctx, key, usedTokenMarker, store.WithExpiration(ttl)); err != nil {
return fmt.Errorf("failed to store used token entry: %w", err)
}
return nil
}
func hashToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}

View File

@@ -0,0 +1,82 @@
package auth
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
func newTestSessionStore(t *testing.T) *SessionStore {
t.Helper()
cacheStore, err := nbcache.NewStore(context.Background(), time.Hour, time.Hour, 100)
require.NoError(t, err)
return NewSessionStore(cacheStore)
}
func TestSessionStore_FirstRegisterSucceeds(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()
require.NoError(t, s.RegisterToken(ctx, "token", time.Now().Add(time.Hour)))
}
func TestSessionStore_RegisterSameTokenTwiceIsRejected(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()
token := "token"
exp := time.Now().Add(time.Hour)
require.NoError(t, s.RegisterToken(ctx, token, exp))
err := s.RegisterToken(ctx, token, exp)
require.Error(t, err)
assert.ErrorIs(t, err, ErrTokenAlreadyUsed)
}
func TestSessionStore_RegisterDifferentTokensAreIndependent(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()
exp := time.Now().Add(time.Hour)
require.NoError(t, s.RegisterToken(ctx, "tokenA", exp))
require.NoError(t, s.RegisterToken(ctx, "tokenB", exp))
}
func TestSessionStore_RegisterWithPastExpiryIsRejected(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()
token := "token"
err := s.RegisterToken(ctx, token, time.Now().Add(-time.Second))
require.Error(t, err)
assert.ErrorIs(t, err, ErrTokenExpired)
}
func TestSessionStore_EntryEvictsAtTTLAndAllowsReRegistration(t *testing.T) {
s := newTestSessionStore(t)
ctx := context.Background()
token := "token"
require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond)))
err := s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond))
assert.ErrorIs(t, err, ErrTokenAlreadyUsed)
time.Sleep(120 * time.Millisecond)
require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(time.Hour)))
}
func TestHashToken_StableAndDoesNotLeak(t *testing.T) {
a := hashToken("tokenA")
b := hashToken("tokenB")
assert.Equal(t, a, hashToken("tokenA"), "hash must be deterministic")
assert.NotEqual(t, a, b, "different tokens must hash differently")
assert.Len(t, a, 64, "sha256 hex must be 64 chars")
assert.NotContains(t, a, "tokenA", "raw token must not appear in hash")
}

View File

@@ -86,7 +86,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
}
return nil

View File

@@ -117,7 +117,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
}
return nil
@@ -189,7 +189,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -257,7 +257,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
}
return globalErr
@@ -305,7 +305,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return globalErr
@@ -512,7 +512,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -550,7 +550,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -582,7 +582,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -620,7 +620,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil

View File

@@ -62,9 +62,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
apiPrefix = "/api"
)
const apiPrefix = "/api"
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
@@ -141,7 +139,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
zonesManager.RegisterEndpoints(router, zManager)
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddEndpoints(instanceManager, accountManager, router)
instance.AddVersionEndpoint(instanceManager, router)
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)

View File

@@ -7,6 +7,7 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -15,13 +16,15 @@ import (
// handler handles the instance setup HTTP endpoints
type handler struct {
instanceManager nbinstance.Manager
setupManager *nbinstance.SetupService
}
// AddEndpoints registers the instance setup endpoints.
// These endpoints bypass authentication for initial setup.
func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
func AddEndpoints(instanceManager nbinstance.Manager, accountManager account.Manager, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
setupManager: nbinstance.NewSetupService(instanceManager, accountManager),
}
router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS")
@@ -55,24 +58,36 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
// setup creates the initial admin user for the instance.
// This endpoint is unauthenticated but only works when setup is required.
func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req api.SetupRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
return
}
userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name)
result, err := h.setupManager.SetupOwner(ctx, req.Email, req.Password, req.Name, nbinstance.SetupOptions{
CreatePAT: req.CreatePat != nil && *req.CreatePat,
PATExpireInDays: req.PatExpireIn,
})
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email)
log.WithContext(ctx).Infof("instance setup completed: created user %s", req.Email)
util.WriteJSONObject(r.Context(), w, api.SetupResponse{
UserId: userData.ID,
Email: userData.Email,
})
resp := api.SetupResponse{
UserId: result.User.ID,
Email: result.User.Email,
}
if result.PATPlainToken != "" {
resp.PersonalAccessToken = &result.PATPlainToken
}
w.Header().Set("Cache-Control", "no-store")
util.WriteJSONObject(ctx, w, resp)
}
// getVersionInfo returns version information for NetBird components.

View File

@@ -10,12 +10,18 @@ import (
"net/mail"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/idp"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/mock_server"
nbstore "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -25,6 +31,7 @@ type mockInstanceManager struct {
isSetupRequired bool
isSetupRequiredFn func(ctx context.Context) (bool, error)
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
rollbackSetupFn func(ctx context.Context, userID string) error
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
}
@@ -67,6 +74,13 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
}, nil
}
func (m *mockInstanceManager) RollbackSetup(ctx context.Context, userID string) error {
if m.rollbackSetupFn != nil {
return m.rollbackSetupFn(ctx, userID)
}
return nil
}
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
if m.getVersionInfoFn != nil {
return m.getVersionInfoFn(ctx)
@@ -82,8 +96,12 @@ func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.V
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
return setupTestRouterWithPAT(manager, nil)
}
func setupTestRouterWithPAT(manager nbinstance.Manager, accountManager account.Manager) *mux.Router {
router := mux.NewRouter()
AddEndpoints(manager, router)
AddEndpoints(manager, accountManager, router)
return router
}
@@ -161,6 +179,7 @@ func TestSetup_Success(t *testing.T) {
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no-store", rec.Header().Get("Cache-Control"))
var response api.SetupResponse
err := json.NewDecoder(rec.Body).Decode(&response)
@@ -293,6 +312,239 @@ func TestSetup_ManagerError(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestSetup_PAT_FeatureDisabled_IgnoresCreatePAT(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "false")
manager := &mockInstanceManager{isSetupRequired: true}
// NB_SETUP_PAT_ENABLED=false: request fields must be silently ignored
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
assert.Nil(t, response.PersonalAccessToken)
}
func TestSetup_PAT_FlagOmitted_NoPAT(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
assert.Nil(t, response.PersonalAccessToken)
}
func TestSetup_PAT_MissingExpireIn_DefaultsToOneDay(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
createCalled := false
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
createCalled = true
return &idp.UserData{ID: "u1", Email: email, Name: name}, nil
},
}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
assert.Equal(t, "u1", userAuth.UserId)
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "u1", initiator)
assert.Equal(t, "u1", target)
assert.Equal(t, "setup-token", name)
assert.Equal(t, 1, expiresIn)
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
},
}
router := setupTestRouterWithPAT(manager, accountMgr)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no-store", rec.Header().Get("Cache-Control"))
assert.True(t, createCalled)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
require.NotNil(t, response.PersonalAccessToken)
assert.Equal(t, "nbp_plain", *response.PersonalAccessToken)
}
func TestSetup_PAT_ExpireOutOfRange(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 0}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_PAT_Success(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
}
gotAccountArgs := struct {
userID string
email string
}{}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
gotAccountArgs.userID = userAuth.UserId
gotAccountArgs.email = userAuth.Email
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "owner-id", initiator)
assert.Equal(t, "owner-id", target)
assert.Equal(t, "setup-token", name)
assert.Equal(t, 30, expiresIn)
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
},
}
router := setupTestRouterWithPAT(manager, accountMgr)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no-store", rec.Header().Get("Cache-Control"))
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
assert.Equal(t, "owner-id", response.UserId)
require.NotNil(t, response.PersonalAccessToken)
assert.Equal(t, "nbp_plain", *response.PersonalAccessToken)
assert.Equal(t, "owner-id", gotAccountArgs.userID)
assert.Equal(t, "admin@example.com", gotAccountArgs.email)
}
func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("", status.NewAccountNotFoundError("owner-id"))
rolledBackFor := ""
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "", errors.New("db down")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
}
router := setupTestRouterWithPAT(manager, accountMgr)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called with the created user id")
}
func TestSetup_PAT_CreatePATFails_Rollback(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
rolledBackFor := ""
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
return nil, status.Errorf(status.Internal, "token store unavailable")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
}
router := setupTestRouterWithPAT(manager, accountMgr)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called when CreatePAT fails")
}
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
router := mux.NewRouter()

View File

@@ -12,6 +12,7 @@ import (
"sync"
"time"
"github.com/dexidp/dex/storage"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
@@ -60,6 +61,13 @@ type Manager interface {
// This should only be called when IsSetupRequired returns true.
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
// RollbackSetup reverses a successful CreateOwnerUser by deleting the user
// from the embedded IDP and reloading setupRequired from persistent state, so
// /api/setup can be retried only when no accounts or local users remain. Used
// when post-user steps (account or PAT creation) fail and the caller wants a
// clean slate.
RollbackSetup(ctx context.Context, userID string) error
// GetVersionInfo returns version information for NetBird components.
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
}
@@ -70,6 +78,7 @@ type instanceStore interface {
type embeddedIdP interface {
CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error)
DeleteUser(ctx context.Context, userID string) error
GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error)
}
@@ -187,6 +196,51 @@ func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, n
return userData, nil
}
// RollbackSetup undoes a successful CreateOwnerUser: deletes the user from the
// embedded IDP and reloads setupRequired from persistent state.
func (m *DefaultManager) RollbackSetup(ctx context.Context, userID string) error {
if m.embeddedIdpManager == nil {
return errors.New("embedded IDP is not enabled")
}
var deleteErr error
if err := m.embeddedIdpManager.DeleteUser(ctx, userID); err != nil {
if isNotFoundError(err) {
log.WithContext(ctx).Debugf("setup rollback user %s already deleted", userID)
} else {
deleteErr = fmt.Errorf("failed to delete user from embedded IdP: %w", err)
}
}
if err := m.loadSetupRequired(ctx); err != nil {
reloadErr := fmt.Errorf("failed to reload setup state after rollback: %w", err)
if deleteErr != nil {
return errors.Join(deleteErr, reloadErr)
}
return reloadErr
}
if deleteErr != nil {
return deleteErr
}
log.WithContext(ctx).Infof("rolled back setup for user %s", userID)
return nil
}
func isNotFoundError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, storage.ErrNotFound) {
return true
}
if s, ok := status.FromError(err); ok {
return s.Type() == status.NotFound
}
return false
}
func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error {
numAccounts, err := m.store.GetAccountsCounter(ctx)
if err != nil {

View File

@@ -10,16 +10,19 @@ import (
"testing"
"time"
"github.com/dexidp/dex/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/shared/management/status"
)
type mockIdP struct {
mu sync.Mutex
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
users map[string][]*idp.UserData
mu sync.Mutex
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
deleteUserFunc func(ctx context.Context, userID string) error
users map[string][]*idp.UserData
getAllAccountsErr error
}
@@ -30,6 +33,13 @@ func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, n
return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil
}
func (m *mockIdP) DeleteUser(ctx context.Context, userID string) error {
if m.deleteUserFunc != nil {
return m.deleteUserFunc(ctx, userID)
}
return nil
}
func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -223,6 +233,77 @@ func TestIsSetupRequired_ReturnsFlag(t *testing.T) {
assert.False(t, required)
}
func TestRollbackSetup_UserAlreadyDeletedIsSuccess(t *testing.T) {
tests := []struct {
name string
err error
}{
{
name: "management status not found",
err: status.NewUserNotFoundError("owner-id"),
},
{
name: "dex storage not found",
err: fmt.Errorf("failed to get user for deletion: %w", storage.ErrNotFound),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
idpMock := &mockIdP{
deleteUserFunc: func(_ context.Context, userID string) error {
assert.Equal(t, "owner-id", userID)
return tt.err
},
}
mgr := newTestManager(idpMock, &mockStore{})
mgr.setupRequired = false
err := mgr.RollbackSetup(context.Background(), "owner-id")
require.NoError(t, err)
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup should be required when no accounts or local users remain")
})
}
}
func TestRollbackSetup_RecomputesSetupStateWhenAccountStillExists(t *testing.T) {
idpMock := &mockIdP{
deleteUserFunc: func(_ context.Context, _ string) error {
return status.NewUserNotFoundError("owner-id")
},
}
mgr := newTestManager(idpMock, &mockStore{accountsCount: 1})
mgr.setupRequired = true
err := mgr.RollbackSetup(context.Background(), "owner-id")
require.NoError(t, err)
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required while an account still exists")
}
func TestRollbackSetup_ReturnsDeleteErrorButReloadsSetupState(t *testing.T) {
idpMock := &mockIdP{
deleteUserFunc: func(_ context.Context, _ string) error {
return errors.New("idp unavailable")
},
}
mgr := newTestManager(idpMock, &mockStore{})
mgr.setupRequired = false
err := mgr.RollbackSetup(context.Background(), "owner-id")
require.Error(t, err)
assert.Contains(t, err.Error(), "idp unavailable")
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup state should be reloaded even when user deletion fails")
}
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
manager := &DefaultManager{setupRequired: true}

View File

@@ -0,0 +1,216 @@
package instance
import (
"context"
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
setupPATTokenName = "setup-token"
// SetupPATEnabledEnvKey enables setup-time Personal Access Token creation.
SetupPATEnabledEnvKey = "NB_SETUP_PAT_ENABLED"
setupPATDefaultExpireDays = 1
)
// SetupOptions controls optional work performed during initial instance setup.
type SetupOptions struct {
// CreatePAT requests creation of a setup Personal Access Token. It is honored
// only when SetupPATEnabledEnvKey is set to "true".
CreatePAT bool
// PATExpireInDays defaults to 1 day when CreatePAT is requested and setup PAT
// creation is enabled.
PATExpireInDays *int
}
// SetupResult contains resources created during initial instance setup.
type SetupResult struct {
User *idp.UserData
PATPlainToken string
}
// SetupService orchestrates the initial setup use case across the instance and
// account bounded contexts and owns the compensation logic when a later step
// fails.
type SetupService struct {
instanceManager Manager
accountManager account.Manager
setupPATEnabled bool
}
// NewSetupService creates a setup use-case service.
func NewSetupService(instanceManager Manager, accountManager account.Manager) *SetupService {
return &SetupService{
instanceManager: instanceManager,
accountManager: accountManager,
setupPATEnabled: os.Getenv(SetupPATEnabledEnvKey) == "true",
}
}
func normalizeSetupOptions(opts SetupOptions, setupPATEnabled bool) (SetupOptions, error) {
if !opts.CreatePAT {
return opts, nil
}
if !setupPATEnabled {
opts.CreatePAT = false
opts.PATExpireInDays = nil
return opts, nil
}
if opts.PATExpireInDays == nil {
defaultExpireInDays := setupPATDefaultExpireDays
opts.PATExpireInDays = &defaultExpireInDays
}
if *opts.PATExpireInDays < account.PATMinExpireDays || *opts.PATExpireInDays > account.PATMaxExpireDays {
return opts, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays)
}
return opts, nil
}
// SetupOwner creates the initial owner user and, when requested and enabled by
// SetupPATEnabledEnvKey, provisions the account and a setup Personal Access
// Token. If account or PAT provisioning fails, created resources are rolled
// back so setup can be retried. If account rollback fails, user rollback is
// skipped to avoid leaving an account without its owner user.
func (m *SetupService) SetupOwner(ctx context.Context, email, password, name string, opts SetupOptions) (*SetupResult, error) {
opts, err := normalizeSetupOptions(opts, m.setupPATEnabled)
if err != nil {
return nil, err
}
if opts.CreatePAT && m.accountManager == nil {
return nil, fmt.Errorf("account manager is required to create setup PAT")
}
userData, err := m.instanceManager.CreateOwnerUser(ctx, email, password, name)
if err != nil {
return nil, err
}
result := &SetupResult{User: userData}
if !opts.CreatePAT {
return result, nil
}
userAuth := auth.UserAuth{
UserId: userData.ID,
Email: userData.Email,
Name: userData.Name,
}
accountID, err := m.accountManager.GetAccountIDByUserID(ctx, userAuth)
if err != nil {
err = fmt.Errorf("create account for setup user: %w", err)
if rollbackErr := m.rollbackSetup(ctx, userData.ID, "account provisioning failed", err, ""); rollbackErr != nil {
return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr)
}
return nil, err
}
pat, err := m.accountManager.CreatePAT(ctx, accountID, userData.ID, userData.ID, setupPATTokenName, *opts.PATExpireInDays)
if err != nil {
err = fmt.Errorf("create setup PAT: %w", err)
if rollbackErr := m.rollbackSetup(ctx, userData.ID, "setup PAT provisioning failed", err, accountID); rollbackErr != nil {
return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr)
}
return nil, err
}
result.PATPlainToken = pat.PlainToken
return result, nil
}
func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string, origErr error, accountID string) error {
if accountID == "" {
resolvedAccountID, err := m.lookupSetupAccountIDForRollback(ctx, userID)
if err != nil {
rollbackErr := fmt.Errorf("resolve setup account for rollback: %w", err)
log.WithContext(ctx).Errorf("failed to resolve setup account for user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, rollbackErr)
return rollbackErr
}
accountID = resolvedAccountID
}
if accountID != "" {
if err := m.rollbackSetupAccount(ctx, accountID); err != nil {
rollbackErr := fmt.Errorf("roll back setup account %s: %w", accountID, err)
log.WithContext(ctx).Errorf("failed to roll back setup account %s for user %s after %s: original error: %v, rollback error: %v", accountID, userID, reason, origErr, rollbackErr)
return rollbackErr
}
log.WithContext(ctx).Warnf("rolled back setup account %s for user %s after %s: %v", accountID, userID, reason, origErr)
}
if err := m.instanceManager.RollbackSetup(ctx, userID); err != nil {
rollbackErr := fmt.Errorf("roll back setup user %s: %w", userID, err)
log.WithContext(ctx).Errorf("failed to roll back setup user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, rollbackErr)
return rollbackErr
}
log.WithContext(ctx).Warnf("rolled back setup user %s after %s: %v", userID, reason, origErr)
return nil
}
func (m *SetupService) lookupSetupAccountIDForRollback(ctx context.Context, userID string) (string, error) {
if m.accountManager == nil {
return "", fmt.Errorf("account manager is required to resolve setup account")
}
accountStore := m.accountManager.GetStore()
if accountStore == nil {
return "", fmt.Errorf("account store is unavailable")
}
accountID, err := accountStore.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
if isNotFoundError(err) {
return "", nil
}
return "", fmt.Errorf("get setup account ID for rollback: %w", err)
}
return accountID, nil
}
// rollbackSetupAccount removes only the setup-created account data from the
// store. It intentionally avoids accountManager.DeleteAccount because the normal
// account deletion path also deletes users from the IdP; embedded IdP cleanup is
// owned by instanceManager.RollbackSetup.
func (m *SetupService) rollbackSetupAccount(ctx context.Context, accountID string) error {
if m.accountManager == nil {
return fmt.Errorf("account manager is required to roll back setup account")
}
accountStore := m.accountManager.GetStore()
if accountStore == nil {
return fmt.Errorf("account store is unavailable")
}
account, err := accountStore.GetAccount(ctx, accountID)
if err != nil {
if isNotFoundError(err) {
return nil
}
return fmt.Errorf("get setup account for rollback: %w", err)
}
if err := accountStore.DeleteAccount(ctx, account); err != nil {
if isNotFoundError(err) {
return nil
}
return fmt.Errorf("delete setup account for rollback: %w", err)
}
return nil
}

View File

@@ -0,0 +1,318 @@
package instance
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/mock_server"
nbstore "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/status"
)
type setupInstanceManagerMock struct {
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
rollbackSetupFn func(ctx context.Context, userID string) error
}
func (m *setupInstanceManagerMock) IsSetupRequired(context.Context) (bool, error) {
return true, nil
}
func (m *setupInstanceManagerMock) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createOwnerUserFn != nil {
return m.createOwnerUserFn(ctx, email, password, name)
}
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
}
func (m *setupInstanceManagerMock) RollbackSetup(ctx context.Context, userID string) error {
if m.rollbackSetupFn != nil {
return m.rollbackSetupFn(ctx, userID)
}
return nil
}
func (m *setupInstanceManagerMock) GetVersionInfo(context.Context) (*VersionInfo, error) {
return &VersionInfo{}, nil
}
var _ Manager = (*setupInstanceManagerMock)(nil)
func intPtr(v int) *int {
return &v
}
func TestSetupOwner_PATFeatureDisabled_IgnoresCreatePAT(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "false")
createCalls := 0
setupManager := NewSetupService(
&setupInstanceManagerMock{
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
createCalls++
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
},
&mock_server.MockAccountManager{},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
})
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, "owner-id", result.User.ID)
assert.Empty(t, result.PATPlainToken)
assert.Equal(t, 1, createCalls)
}
func TestSetupOwner_PATFeatureEnabled_MissingExpireDefaultsToOneDay(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
createCalled := false
setupManager := NewSetupService(
&setupInstanceManagerMock{
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
createCalled = true
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
assert.Equal(t, "owner-id", userAuth.UserId)
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "owner-id", initiatorUserID)
assert.Equal(t, "owner-id", targetUserID)
assert.Equal(t, setupPATTokenName, tokenName)
assert.Equal(t, setupPATDefaultExpireDays, expiresIn)
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
})
require.NoError(t, err)
require.NotNil(t, result)
assert.True(t, createCalled)
assert.Equal(t, "nbp_plain", result.PATPlainToken)
}
func TestSetupOwner_PATFeatureEnabled_MissingAccountManagerFailsBeforeCreateUser(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
createCalled := false
rollbackCalled := false
setupManager := NewSetupService(
&setupInstanceManagerMock{
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
createCalled = true
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
rollbackSetupFn: func(_ context.Context, _ string) error {
rollbackCalled = true
return nil
},
},
nil,
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "account manager is required")
assert.False(t, createCalled)
assert.False(t, rollbackCalled)
}
func TestSetupOwner_AccountProvisioningFails_RollsBackSideEffectAccountAndUser(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("acc-1", nil)
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
rolledBackFor := ""
rollbackCalls := 0
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rollbackCalls++
rolledBackFor = userID
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
assert.Equal(t, "owner-id", userAuth.UserId)
return "", errors.New("metadata update failed")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: intPtr(30),
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create account for setup user")
assert.Equal(t, "owner-id", rolledBackFor)
assert.Equal(t, 1, rollbackCalls)
}
func TestSetupOwner_CreatePATFails_RollsBackSetupAccountAndUser(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
rollbackCalls := 0
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rollbackCalls++
assert.Equal(t, "owner-id", userID)
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
assert.Equal(t, "owner-id", userAuth.UserId)
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "owner-id", initiatorUserID)
assert.Equal(t, "owner-id", targetUserID)
assert.Equal(t, setupPATTokenName, tokenName)
assert.Equal(t, 30, expiresIn)
return nil, status.Errorf(status.Internal, "token store unavailable")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: intPtr(30),
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create setup PAT")
assert.Equal(t, 1, rollbackCalls)
}
func TestSetupOwner_CreatePATFails_AccountAlreadyGoneStillRollsBackUser(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(nil, status.NewAccountNotFoundError("acc-1"))
rolledBackFor := ""
rollbackCalls := 0
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rollbackCalls++
rolledBackFor = userID
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
return nil, errors.New("token failure")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: intPtr(30),
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create setup PAT")
assert.Equal(t, "owner-id", rolledBackFor)
assert.Equal(t, 1, rollbackCalls)
}
func TestSetupOwner_CreatePATFails_AccountRollbackFailureStopsBeforeUserRollback(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(errors.New("delete failed"))
rollbackCalls := 0
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rollbackCalls++
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
return nil, errors.New("token failure")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: intPtr(30),
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create setup PAT")
assert.Contains(t, err.Error(), "failed to roll back setup resources")
assert.Equal(t, 0, rollbackCalls)
}

View File

@@ -391,7 +391,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, nil, "", cleanup, err
}

View File

@@ -256,6 +256,7 @@ func startServer(
server.MockIntegratedValidator{},
networkMapController,
nil,
nil,
)
if err != nil {
t.Fatalf("failed creating management server: %v", err)

View File

@@ -129,8 +129,8 @@ type MockAccountManager struct {
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
GetIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error)
@@ -201,15 +201,15 @@ func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userI
return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented")
}
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
if am.UpdateAccountPeersFunc != nil {
am.UpdateAccountPeersFunc(ctx, accountID)
am.UpdateAccountPeersFunc(ctx, accountID, reason)
}
}
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
if am.BufferUpdateAccountPeersFunc != nil {
am.BufferUpdateAccountPeersFunc(ctx, accountID)
am.BufferUpdateAccountPeersFunc(ctx, accountID, reason)
}
}

View File

@@ -82,7 +82,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate})
}
return newNSGroup.Copy(), nil
@@ -133,7 +133,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -176,7 +176,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete})
}
return nil

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
serverTypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -177,7 +178,7 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete})
return nil
}

View File

@@ -162,7 +162,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
event()
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate})
return resource, nil
}
@@ -270,7 +270,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
}
}()
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate})
return resource, nil
}
@@ -352,7 +352,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete})
return nil
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
serverTypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -119,7 +120,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate})
return router, nil
}
@@ -183,7 +184,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate})
return router, nil
}
@@ -217,7 +218,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
event()
go m.accountManager.UpdateAccountPeers(ctx, accountID)
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete})
return nil
}

View File

@@ -1272,12 +1272,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// UpdateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID)
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
}
// UpdateAccountPeer updates a single peer that belongs to an account.

View File

@@ -985,7 +985,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now()
for i := 0; i < b.N; i++ {
manager.UpdateAccountPeers(ctx, account.Id)
manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{})
}
duration := time.Since(start)
@@ -1043,7 +1043,7 @@ func testUpdateAccountPeers(t *testing.T) {
peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID)
}
manager.UpdateAccountPeers(ctx, account.Id)
manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{})
for _, channel := range peerChannels {
update := <-channel

View File

@@ -96,7 +96,11 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
policyOp := types.UpdateOperationCreate
if isUpdate {
policyOp = types.UpdateOperationUpdate
}
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: policyOp})
}
return policy, nil
@@ -139,7 +143,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: types.UpdateOperationDelete})
}
return nil

View File

@@ -17,19 +17,48 @@ type PeerNetworkRangeCheck struct {
var _ Check = (*PeerNetworkRangeCheck)(nil)
// prefixContains reports whether outer fully contains inner (equal counts as contained).
// Requires the same address family, that outer is no more specific than inner (its
// netmask is shorter or equal), and that inner's network address falls inside outer.
// This is stricter than netip.Prefix.Contains(Addr) — a peer's /24 NIC will not match a
// configured /32 rule, since the rule covers a single host but the NIC describes a whole
// subnet whose host bits are unknown.
func prefixContains(outer, inner netip.Prefix) bool {
outer = outer.Masked()
inner = inner.Masked()
return outer.Bits() <= inner.Bits() &&
outer.Addr().BitLen() == inner.Addr().BitLen() && // same family
outer.Contains(inner.Addr())
}
// Check evaluates configured ranges against the peer's local network interface prefixes
// and its public connection IP (as a /32 or /128). A configured range matches when it
// fully contains one of those prefixes, so operators can target both private subnets
// and public CIDRs (e.g. 1.0.0.0/24, 2.2.2.2/32). Including the connection IP is what
// lets a public-range posture check work — peer.Meta.NetworkAddresses only carries
// local NIC addresses.
func (p *PeerNetworkRangeCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
if len(peer.Meta.NetworkAddresses) == 0 {
peerPrefixes := make([]netip.Prefix, 0, len(peer.Meta.NetworkAddresses)+1)
for _, peerNetAddr := range peer.Meta.NetworkAddresses {
peerPrefixes = append(peerPrefixes, peerNetAddr.NetIP)
}
// Unmap collapses 4-in-6 forms (::ffff:a.b.c.d) so an IPv4 range matches.
if connIP := peer.Location.ConnectionIP; len(connIP) > 0 {
if addr, ok := netip.AddrFromSlice(connIP); ok {
addr = addr.Unmap()
peerPrefixes = append(peerPrefixes, netip.PrefixFrom(addr, addr.BitLen()))
}
}
if len(peerPrefixes) == 0 {
return false, fmt.Errorf("peer's does not contain peer network range addresses")
}
maskedPrefixes := make([]netip.Prefix, 0, len(p.Ranges))
for _, prefix := range p.Ranges {
maskedPrefixes = append(maskedPrefixes, prefix.Masked())
}
for _, peerNetAddr := range peer.Meta.NetworkAddresses {
peerMaskedPrefix := peerNetAddr.NetIP.Masked()
if slices.Contains(maskedPrefixes, peerMaskedPrefix) {
for _, peerPrefix := range peerPrefixes {
for _, rangePrefix := range p.Ranges {
if !prefixContains(rangePrefix, peerPrefix) {
continue
}
switch p.Action {
case CheckActionDeny:
return false, nil

View File

@@ -2,6 +2,7 @@ package posture
import (
"context"
"net"
"net/netip"
"testing"
@@ -134,6 +135,205 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) {
wantErr: true,
isValid: false,
},
{
name: "Peer connection IP matches the denied /32",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{
netip.MustParsePrefix("109.41.115.194/32"),
},
},
peer: nbpeer.Peer{
Meta: nbpeer.PeerSystemMeta{
NetworkAddresses: []nbpeer.NetworkAddress{
{NetIP: netip.MustParsePrefix("192.168.0.123/24")},
},
},
Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")},
},
wantErr: false,
isValid: false,
},
{
name: "Peer connection IP does not match the denied /32",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{
netip.MustParsePrefix("109.41.115.194/32"),
},
},
peer: nbpeer.Peer{
Meta: nbpeer.PeerSystemMeta{
NetworkAddresses: []nbpeer.NetworkAddress{
{NetIP: netip.MustParsePrefix("192.168.0.123/24")},
},
},
Location: nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
},
wantErr: false,
isValid: true,
},
{
name: "Peer connection IP matches the allowed /32 with no NetworkAddresses",
check: PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{
netip.MustParsePrefix("109.41.115.194/32"),
},
},
peer: nbpeer.Peer{
Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")},
},
wantErr: false,
isValid: true,
},
{
name: "IPv6 connection IP matches the denied /128",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{
netip.MustParsePrefix("2001:db8::1/128"),
},
},
peer: nbpeer.Peer{
Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::1")},
},
wantErr: false,
isValid: false,
},
{
name: "IPv6 connection IP does not match the denied /128",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{
netip.MustParsePrefix("2001:db8::1/128"),
},
},
peer: nbpeer.Peer{
Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::2")},
},
wantErr: false,
isValid: true,
},
{
name: "IPv4-mapped IPv6 connection IP matches IPv4 /32",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{
netip.MustParsePrefix("109.41.115.194/32"),
},
},
peer: nbpeer.Peer{
Location: nbpeer.Location{ConnectionIP: net.ParseIP("::ffff:109.41.115.194")},
},
wantErr: false,
isValid: false,
},
{
name: "Connection IP falls inside an allowed /24 range",
check: PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{
netip.MustParsePrefix("1.0.0.0/24"),
netip.MustParsePrefix("2.2.2.2/32"),
},
},
peer: nbpeer.Peer{
Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.55")},
},
wantErr: false,
isValid: true,
},
{
name: "Connection IP falls inside an allowed /23 range",
check: PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{
netip.MustParsePrefix("3.0.0.0/23"),
},
},
peer: nbpeer.Peer{
Location: nbpeer.Location{ConnectionIP: net.ParseIP("3.0.1.200")},
},
wantErr: false,
isValid: true,
},
{
name: "Connection IP outside the allowed /24 range",
check: PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{
netip.MustParsePrefix("1.0.0.0/24"),
},
},
peer: nbpeer.Peer{
Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.1.5")},
},
wantErr: false,
isValid: false,
},
{
name: "Connection IP inside a denied /24 range",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{
netip.MustParsePrefix("1.0.0.0/24"),
},
},
peer: nbpeer.Peer{
Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.7")},
},
wantErr: false,
isValid: false,
},
{
name: "Local NIC /24 does not match a /32 rule even if host bit lines up",
check: PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.5/32"),
},
},
peer: nbpeer.Peer{
Meta: nbpeer.PeerSystemMeta{
NetworkAddresses: []nbpeer.NetworkAddress{
{NetIP: netip.MustParsePrefix("192.168.0.5/24")},
},
},
},
wantErr: false,
isValid: false,
},
{
name: "Local NIC address inside an allowed /16 range",
check: PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
peer: nbpeer.Peer{
Meta: nbpeer.PeerSystemMeta{
NetworkAddresses: []nbpeer.NetworkAddress{
{NetIP: netip.MustParsePrefix("192.168.5.7/24")},
},
},
},
wantErr: false,
isValid: true,
},
{
name: "Empty NetworkAddresses and empty ConnectionIP still errors",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{
netip.MustParsePrefix("109.41.115.194/32"),
},
},
peer: nbpeer.Peer{},
wantErr: true,
isValid: false,
},
}
for _, tt := range tests {

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -76,7 +77,11 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
postureOp := types.UpdateOperationCreate
if isUpdate {
postureOp = types.UpdateOperationUpdate
}
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePostureCheck, Operation: postureOp})
}
return postureChecks, nil

View File

@@ -191,7 +191,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationCreate})
}
return newRoute, nil
@@ -245,7 +245,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationDelete})
}
return nil

View File

@@ -4,6 +4,7 @@ import (
"context"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
@@ -11,6 +12,7 @@ import (
type AccountManagerMetrics struct {
ctx context.Context
updateAccountPeersDurationMs metric.Float64Histogram
updateAccountPeersCounter metric.Int64Counter
getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram
peerMetaUpdateCount metric.Int64Counter
@@ -48,6 +50,13 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
return nil, err
}
updateAccountPeersCounter, err := meter.Int64Counter("management.account.update.account.peers.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of account peers updates triggered, labeled by resource and operation"))
if err != nil {
return nil, err
}
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of updates with new meta data from the peers"))
@@ -59,6 +68,7 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
updateAccountPeersCounter: updateAccountPeersCounter,
networkMapObjectCount: networkMapObjectCount,
peerMetaUpdateCount: peerMetaUpdateCount,
}, nil
@@ -80,6 +90,16 @@ func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
metrics.networkMapObjectCount.Record(metrics.ctx, count)
}
// CountUpdateAccountPeersTriggered increments the counter for account peers updates with resource and operation labels.
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource, operation string) {
metrics.updateAccountPeersCounter.Add(metrics.ctx, 1,
metric.WithAttributes(
attribute.String("resource", resource),
attribute.String("operation", operation),
),
)
}
// CountPeerMetUpdate counts the number of peer meta updates
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)

View File

@@ -0,0 +1,37 @@
package types
// UpdateReason describes why an account peers update was triggered.
type UpdateReason struct {
Resource UpdateResource
Operation UpdateOperation
}
// UpdateResource represents the kind of resource that triggered an account peers update.
type UpdateResource string
const (
UpdateResourceAccountSettings UpdateResource = "account_settings"
UpdateResourceDNSSettings UpdateResource = "dns_settings"
UpdateResourceGroup UpdateResource = "group"
UpdateResourceNameServerGroup UpdateResource = "nameserver_group"
UpdateResourcePolicy UpdateResource = "policy"
UpdateResourcePostureCheck UpdateResource = "posture_check"
UpdateResourceRoute UpdateResource = "route"
UpdateResourceUser UpdateResource = "user"
UpdateResourcePeer UpdateResource = "peer"
UpdateResourceNetwork UpdateResource = "network"
UpdateResourceNetworkResource UpdateResource = "network_resource"
UpdateResourceNetworkRouter UpdateResource = "network_router"
UpdateResourceService UpdateResource = "service"
UpdateResourceZone UpdateResource = "zone"
UpdateResourceZoneRecord UpdateResource = "zone_record"
)
// UpdateOperation represents the kind of change that triggered the update.
type UpdateOperation string
const (
UpdateOperationCreate UpdateOperation = "create"
UpdateOperationUpdate UpdateOperation = "update"
UpdateOperationDelete UpdateOperation = "delete"
)

View File

@@ -16,6 +16,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -396,8 +397,8 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
}
if expiresIn < 1 || expiresIn > 365 {
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
if expiresIn < account.PATMinExpireDays || expiresIn > account.PATMaxExpireDays {
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays)
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Create)
@@ -676,7 +677,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
return nil, fmt.Errorf("failed to increment network serial: %w", err)
}
am.UpdateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate})
}
return updatedUsersInfo, globalErr