Merge remote-tracking branch 'origin/main' into refactor/permissions-manager

# Conflicts:
#	management/internals/modules/reverseproxy/domain/manager/manager.go
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/internals/server/modules.go
#	management/server/http/testing/testing_tools/channel/channel.go
This commit is contained in:
pascal
2026-03-17 12:38:08 +01:00
244 changed files with 17304 additions and 3509 deletions

View File

@@ -325,7 +325,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain ||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways {
updateAccountPeers = true
}
@@ -366,6 +367,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleAutoUpdateAlwaysSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
@@ -483,6 +485,16 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con
}
}
func (am *DefaultAccountManager) handleAutoUpdateAlwaysSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways {
if newSettings.AutoUpdateAlways {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateAlwaysEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateAlwaysDisabled, nil)
}
}
}
func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
oldEnabled := oldSettings.PeerExposeEnabled
newEnabled := newSettings.PeerExposeEnabled

View File

@@ -63,11 +63,20 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context,
log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
startTime := time.Now()
ac.getAccountRequestCh <- req
result := <-req.ResultChan
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
return result.Account, result.Err
select {
case <-ctx.Done():
return nil, ctx.Err()
case ac.getAccountRequestCh <- req:
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-req.ResultChan:
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
return result.Account, result.Err
}
}
func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {
@@ -86,7 +95,14 @@ func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, acco
result := &AccountResult{Account: account, Err: err}
for _, req := range requests {
req.ResultChan <- result
if account != nil {
// Shallow copy the account so each goroutine gets its own struct value.
// This prevents data races when callers mutate fields like Policies.
accountCopy := *account
req.ResultChan <- &AccountResult{Account: &accountCopy, Err: err}
} else {
req.ResultChan <- result
}
close(req.ResultChan)
}
}

View File

@@ -3116,7 +3116,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
peersManager := peers.NewManager(store)
proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT().
@@ -3128,18 +3128,18 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, nil))
return manager, updateManager, nil
}

View File

@@ -220,6 +220,18 @@ const (
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
AccountPeerExposeDisabled Activity = 115
// AccountAutoUpdateAlwaysEnabled indicates that a user enabled always auto-update for the account
AccountAutoUpdateAlwaysEnabled Activity = 116
// AccountAutoUpdateAlwaysDisabled indicates that a user disabled always auto-update for the account
AccountAutoUpdateAlwaysDisabled Activity = 117
// DomainAdded indicates that a user added a custom domain
DomainAdded Activity = 118
// DomainDeleted indicates that a user deleted a custom domain
DomainDeleted Activity = 119
// DomainValidated indicates that a custom domain was validated
DomainValidated Activity = 120
AccountDeleted Activity = 99999
)
@@ -332,6 +344,8 @@ var activityMap = map[Activity]Code{
UserCreated: {"User created", "user.create"},
AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"},
AccountAutoUpdateAlwaysEnabled: {"Account auto-update always enabled", "account.setting.auto.update.always.enable"},
AccountAutoUpdateAlwaysDisabled: {"Account auto-update always disabled", "account.setting.auto.update.always.disable"},
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
@@ -364,6 +378,10 @@ var activityMap = map[Activity]Code{
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"},
}
// StringCode returns a string code of the activity

View File

@@ -222,12 +222,12 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
peersManager := peers.NewManager(store)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}

View File

@@ -44,6 +44,12 @@ type Record struct {
GeonameID uint `maxminddb:"geoname_id"`
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
Subdivisions []struct {
ISOCode string `maxminddb:"iso_code"`
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"subdivisions"`
}
type City struct {

View File

@@ -765,7 +765,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
// Saving a group linked to network router should update account peers and send peer update
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
groupsManager := groups.NewManager(manager.Store, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

View File

@@ -175,7 +175,6 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
}
// Register OAuth callback handler for proxy authentication
if proxyGRPCServer != nil {
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)

View File

@@ -220,6 +220,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
return nil, fmt.Errorf("invalid AutoUpdateVersion")
}
}
if req.Settings.AutoUpdateAlways != nil {
returnSettings.AutoUpdateAlways = *req.Settings.AutoUpdateAlways
}
return returnSettings, nil
}
@@ -329,6 +332,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,
AutoUpdateAlways: &settings.AutoUpdateAlways,
EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled,
LocalAuthDisabled: &settings.LocalAuthDisabled,
}

View File

@@ -80,7 +80,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}
@@ -133,6 +133,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateAlways: br(false),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
@@ -158,6 +159,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateAlways: br(false),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
@@ -183,6 +185,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateAlways: br(false),
AutoUpdateVersion: sr("latest"),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
@@ -208,6 +211,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateAlways: br(false),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
@@ -233,6 +237,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateAlways: br(false),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
@@ -258,6 +263,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateAlways: br(false),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),

View File

@@ -31,7 +31,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}

View File

@@ -24,18 +24,6 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
const (
existingNSGroupID = "existingNSGroupID"
notFoundNSGroupID = "notFoundNSGroupID"

View File

@@ -171,7 +171,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}

View File

@@ -41,7 +41,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}

View File

@@ -95,7 +95,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}

View File

@@ -42,7 +42,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}

View File

@@ -73,7 +73,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}

View File

@@ -21,18 +21,6 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
func initPoliciesTestData(policies ...*types.Policy) *handler {
testPolicies := make(map[string]*types.Policy, len(policies))
for _, policy := range policies {

View File

@@ -26,18 +26,6 @@ import (
var berlin = "Berlin"
var losAngeles = "Los Angeles"
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler {
testPostureChecks := make(map[string]*posture.Checks, len(postureChecks))
for _, postureCheck := range postureChecks {

View File

@@ -193,6 +193,9 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore)
oidcConfig := nbgrpc.ProxyOIDCConfig{
@@ -206,6 +209,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
pkceStore,
oidcConfig,
nil,
usersManager,

View File

@@ -33,7 +33,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}

View File

@@ -30,7 +30,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}

View File

@@ -32,7 +32,7 @@ func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.U
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
h(w, r, &userAuth)
}
}

View File

@@ -857,7 +857,7 @@ func TestRejectUserEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
router.HandleFunc("/users/{userId}/reject", wrapHandler(handler.rejectUser)).Methods("DELETE")
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
require.NoError(t, err)
@@ -948,7 +948,7 @@ func TestChangePasswordEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/password", handler.changePassword).Methods("PUT")
router.HandleFunc("/users/{userId}/password", wrapHandler(handler.changePassword)).Methods("PUT")
reqPath := "/users/" + tc.targetUserID + "/password"
req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody))
@@ -987,7 +987,7 @@ func TestChangePasswordEndpoint_WrongMethod(t *testing.T) {
req = nbcontext.SetUserAuthInRequest(req, userAuth)
rr := httptest.NewRecorder()
handler.changePassword(rr, req)
handler.changePassword(rr, req, &userAuth)
assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
}

View File

@@ -51,19 +51,28 @@ func GetList() []string {
// This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication.
func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool {
byPassMutex.RLock()
defer byPassMutex.RUnlock()
var matched bool
for bypassPath := range bypassPaths {
matched, err := path.Match(bypassPath, requestPath)
m, err := path.Match(bypassPath, requestPath)
if err != nil {
log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
list := make([]string, 0, len(bypassPaths))
for k := range bypassPaths {
list = append(list, k)
}
log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %v: %v", bypassPath, requestPath, list, err)
continue
}
if matched {
h.ServeHTTP(w, r)
return true
if m {
matched = true
break
}
}
byPassMutex.RUnlock()
if matched {
h.ServeHTTP(w, r)
return true
}
return false
}

View File

@@ -80,8 +80,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
proxyController := integrations.NewController(store)
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{})
peersManager := peers.NewManager(store, permissionsManager)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), settings.IdpConfig{})
peersManager := peers.NewManager(store)
jobManager := job.NewJobManager(nil, store, peersManager)
@@ -93,23 +93,28 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
t.Fatalf("Failed to create manager: %v", err)
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
accessLogsManager := accesslogsmanager.NewManager(store, nil)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
domainManager.SetClusterCapabilities(serviceProxyController)
serviceManager := reverseproxymanager.NewManager(store, am, serviceProxyController, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
@@ -126,8 +131,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
customZonesManager := zonesManager.NewManager(store, am, "")
zoneRecordsManager := recordsManager.NewManager(store, am)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil {

View File

@@ -81,7 +81,7 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update
AnyTimes()
permissionsManager := permissions.NewManager(testStore)
peersManager := peers.NewManager(testStore, permissionsManager)
peersManager := peers.NewManager(testStore)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, testStore)

View File

@@ -28,7 +28,13 @@ func NewChannel() *Channel {
return jc
}
func (jc *Channel) AddEvent(ctx context.Context, responseWait time.Duration, event *Event) error {
func (jc *Channel) AddEvent(ctx context.Context, responseWait time.Duration, event *Event) (err error) {
defer func() {
if r := recover(); r != nil {
err = ErrJobChannelClosed
}
}()
select {
case <-ctx.Done():
return ctx.Err()

View File

@@ -362,12 +362,12 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
AnyTimes()
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
peersManager := peers.NewManager(store)
jobManager := job.NewJobManager(nil, store, peersManager)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager))
ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store))
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config)
accountManager, err := BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "",

View File

@@ -203,13 +203,13 @@ func startServer(
AnyTimes()
permissionsManager := permissions.NewManager(str)
peersManager := peers.NewManager(str, permissionsManager)
peersManager := peers.NewManager(str)
jobManager := job.NewJobManager(nil, str, peersManager)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(ctx, str)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str)), config)
accountManager, err := server.BuildManager(
context.Background(),
@@ -232,7 +232,7 @@ func startServer(
t.Fatalf("failed creating an account manager: %v", err)
}
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
groupsManager := groups.NewManager(str, accountManager)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
t.Fatalf("failed creating secrets manager: %v", err)

View File

@@ -219,7 +219,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
servicesStatusActive int
servicesStatusPending int
servicesStatusError int
servicesTargetType map[string]int
servicesTargetType map[rpservice.TargetType]int
servicesAuthPassword int
servicesAuthPin int
servicesAuthOIDC int
@@ -232,7 +232,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
rulesDirection = make(map[string]int)
activeUsersLastDay = make(map[string]struct{})
embeddedIdpTypes = make(map[string]int)
servicesTargetType = make(map[string]int)
servicesTargetType = make(map[rpservice.TargetType]int)
uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers()
version = nbversion.NetbirdVersion()
@@ -434,7 +434,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["custom_domains_validated"] = customDomainsValidated
for targetType, count := range servicesTargetType {
metricsProperties["services_target_type_"+targetType] = count
metricsProperties["services_target_type_"+string(targetType)] = count
}
for idpType, count := range embeddedIdpTypes {

View File

@@ -791,12 +791,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
peersManager := peers.NewManager(store)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}

View File

@@ -1290,12 +1290,12 @@ func Test_RegisterPeerByUser(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
peersManager := peers.NewManager(s, permissionsManager)
peersManager := peers.NewManager(s)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
@@ -1376,12 +1376,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(s)
peersManager := peers.NewManager(s, permissionsManager)
peersManager := peers.NewManager(s)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
@@ -1530,12 +1530,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
peersManager := peers.NewManager(s, permissionsManager)
peersManager := peers.NewManager(s)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
@@ -1611,12 +1611,12 @@ func Test_LoginPeer(t *testing.T) {
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(s)
peersManager := peers.NewManager(s, permissionsManager)
peersManager := peers.NewManager(s)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)

View File

@@ -1290,12 +1290,12 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.
Return(&types.ExtraSettings{}, nil)
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
peersManager := peers.NewManager(store)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {

View File

@@ -30,6 +30,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
@@ -4977,9 +4978,9 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
return service, nil
}
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
var service *rpservice.Service
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
@@ -4996,6 +4997,7 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
return service, nil
}
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
@@ -5041,16 +5043,16 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS
}
// RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service.
func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error {
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Where("id = ? AND account_id = ? AND source_peer = ? AND source = ?", serviceID, accountID, peerID, rpservice.SourceEphemeral).
Update("meta_last_renewed_at", time.Now())
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error)
return status.Errorf(status.Internal, "renew ephemeral service")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
return status.Errorf(status.NotFound, "no active expose session for service %s", serviceID)
}
return nil
}
@@ -5133,6 +5135,37 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock
return id != "", nil
}
// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port.
func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var services []*rpservice.Service
result := tx.Where("proxy_cluster = ? AND mode = ? AND listen_port = ?", proxyCluster, mode, listenPort).Find(&services)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "query services by cluster and port")
}
return services, nil
}
// GetServicesByCluster returns all services for the given proxy cluster.
func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var services []*rpservice.Service
result := tx.Where("proxy_cluster = ?", proxyCluster).Find(&services)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "query services by cluster")
}
return services, nil
}
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
tx := s.db

View File

@@ -257,14 +257,16 @@ type Store interface {
UpdateService(ctx context.Context, service *rpservice.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error
RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error
GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error)
CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error)
EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error)
GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error)
GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)

View File

@@ -1932,18 +1932,18 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
}
// GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain interface{}) *gomock.Call {
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, accountID, domain)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, domain)
}
// GetServiceByID mocks base method.
@@ -1991,6 +1991,36 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength)
}
// GetServicesByCluster mocks base method.
func (m *MockStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByCluster", ctx, lockStrength, proxyCluster)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByCluster indicates an expected call of GetServicesByCluster.
func (mr *MockStoreMockRecorder) GetServicesByCluster(ctx, lockStrength, proxyCluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByCluster", reflect.TypeOf((*MockStore)(nil).GetServicesByCluster), ctx, lockStrength, proxyCluster)
}
// GetServicesByClusterAndPort mocks base method.
func (m *MockStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster, mode string, listenPort uint16) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByClusterAndPort", ctx, lockStrength, proxyCluster, mode, listenPort)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByClusterAndPort indicates an expected call of GetServicesByClusterAndPort.
func (mr *MockStoreMockRecorder) GetServicesByClusterAndPort(ctx, lockStrength, proxyCluster, mode, listenPort interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByClusterAndPort", reflect.TypeOf((*MockStore)(nil).GetServicesByClusterAndPort), ctx, lockStrength, proxyCluster, mode, listenPort)
}
// GetSetupKeyByID mocks base method.
func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) {
m.ctrl.T.Helper()
@@ -2447,17 +2477,17 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID
}
// RenewEphemeralService mocks base method.
func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain)
ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// RenewEphemeralService indicates an expected call of RenewEphemeralService.
func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call {
func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, serviceID)
}
// RevokeProxyAccessToken mocks base method.

View File

@@ -907,8 +907,8 @@ func (a *Account) Copy() *Account {
}
services := []*service.Service{}
for _, service := range a.Services {
services = append(services, service.Copy())
for _, svc := range a.Services {
services = append(services, svc.Copy())
}
return &Account{
@@ -1605,12 +1605,12 @@ func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy {
networkResourceGroups := a.getNetworkResourceGroups(resourceId)
for _, policy := range a.Policies {
if !policy.Enabled {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
if rule == nil || !rule.Enabled {
continue
}
@@ -1812,15 +1812,18 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
}
a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster)
}
}
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
proxyPeers := proxyPeersByCluster[service.ProxyCluster]
for _, target := range service.Targets {
if !target.Enabled {
continue
}
a.injectTargetProxyPolicies(ctx, service, target, proxyPeersByCluster[service.ProxyCluster])
a.injectTargetProxyPolicies(ctx, service, target, proxyPeers)
}
}
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
@@ -1840,13 +1843,13 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *servic
}
}
func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) {
func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (uint16, bool) {
if target.Port != 0 {
return target.Port, true
}
switch target.Protocol {
case "https":
case "https", "tls":
return 443, true
case "http":
return 80, true
@@ -1856,17 +1859,23 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target)
}
}
func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
func (a *Account) createProxyPolicy(svc *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port uint16, path string) *Policy {
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", svc.ID, proxyPeer.ID, path)
protocol := PolicyRuleProtocolTCP
if svc.Mode == service.ModeUDP {
protocol = PolicyRuleProtocolUDP
}
return &Policy{
ID: policyID,
Name: fmt.Sprintf("Proxy Access to %s", service.Name),
Name: fmt.Sprintf("Proxy Access to %s", svc.Name),
Enabled: true,
Rules: []*PolicyRule{
{
ID: policyID,
PolicyID: policyID,
Name: fmt.Sprintf("Allow access to %s", service.Name),
Name: fmt.Sprintf("Allow access to %s", svc.Name),
Enabled: true,
SourceResource: Resource{
ID: proxyPeer.ID,
@@ -1877,12 +1886,12 @@ func (a *Account) createProxyPolicy(service *service.Service, target *service.Ta
Type: ResourceType(target.TargetType),
},
Bidirectional: false,
Protocol: PolicyRuleProtocolTCP,
Protocol: protocol,
Action: PolicyTrafficActionAccept,
PortRanges: []RulePortRange{
{
Start: uint16(port),
End: uint16(port),
Start: port,
End: port,
},
},
},

View File

@@ -368,7 +368,7 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string,
validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) {
peerInGroups := false
filteredPeerIDs := make([]string, 0, len(a.Peers))
filteredPeerIDs := make([]string, 0, len(groups))
seenPeerIds := make(map[string]struct{}, len(groups))
for _, gid := range groups {
@@ -378,7 +378,7 @@ func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerI
}
if group.IsGroupAll() || len(groups) == 1 {
filteredPeerIDs = filteredPeerIDs[:0]
filteredPeerIDs = make([]string, 0, len(group.Peers))
peerInGroups = false
for _, pid := range group.Peers {
peer, ok := a.Peers[pid]

View File

@@ -152,6 +152,8 @@ func (n *Network) CurrentSerial() uint64 {
}
func (n *Network) Copy() *Network {
n.Mu.Lock()
defer n.Mu.Unlock()
return &Network{
Identifier: n.Identifier,
Net: n.Net,

View File

@@ -134,7 +134,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap {
sourcePeers,
)
dnsManagementStatus := c.getPeerDNSManagementStatus(targetPeerID)
dnsManagementStatus := c.getPeerDNSManagementStatusFromGroups(peerGroups)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
@@ -152,7 +152,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap {
customZones = append(customZones, c.AccountZones...)
dnsUpdate.CustomZones = customZones
dnsUpdate.NameServerGroups = c.getPeerNSGroups(targetPeerID)
dnsUpdate.NameServerGroups = c.getPeerNSGroupsFromGroups(targetPeerID, peerGroups)
}
return &NetworkMap{
@@ -278,6 +278,16 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (
peers := make([]*nbpeer.Peer, 0)
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
protocol := rule.Protocol
if protocol == PolicyRuleProtocolNetbirdSSH {
protocol = PolicyRuleProtocolTCP
}
protocolStr := string(protocol)
actionStr := string(rule.Action)
dirStr := strconv.Itoa(direction)
portsJoined := strings.Join(rule.Ports, ",")
for _, peer := range groupPeers {
if peer == nil {
continue
@@ -288,21 +298,18 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (
peersExists[peer.ID] = struct{}{}
}
protocol := rule.Protocol
if protocol == PolicyRuleProtocolNetbirdSSH {
protocol = PolicyRuleProtocolTCP
}
peerIP := net.IP(peer.IP).String()
fr := FirewallRule{
PolicyID: rule.ID,
PeerIP: net.IP(peer.IP).String(),
PeerIP: peerIP,
Direction: direction,
Action: string(rule.Action),
Protocol: string(protocol),
Action: actionStr,
Protocol: protocolStr,
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
ruleID := rule.ID + peerIP + dirStr +
protocolStr + actionStr + portsJoined
if _, ok := rulesExists[ruleID]; ok {
continue
}
@@ -313,13 +320,7 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (
continue
}
rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{
ID: rule.ID,
Ports: rule.Ports,
PortRanges: rule.PortRanges,
Protocol: rule.Protocol,
Action: rule.Action,
}, targetPeer)...)
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
}
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules
@@ -395,7 +396,7 @@ func (c *NetworkMapComponents) getPeerFromResource(resource Resource, peerID str
}
func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) {
var peersToConnect []*nbpeer.Peer
peersToConnect := make([]*nbpeer.Peer, 0, len(aclPeers))
var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers {
@@ -410,35 +411,35 @@ func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.P
return peersToConnect, expiredPeers
}
func (c *NetworkMapComponents) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := c.GetPeerGroups(peerID)
enabled := true
func (c *NetworkMapComponents) getPeerDNSManagementStatusFromGroups(peerGroups map[string]struct{}) bool {
for _, groupID := range c.DNSSettings.DisabledManagementGroups {
if _, found := peerGroups[groupID]; found {
enabled = false
break
return false
}
}
return enabled
return true
}
func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup {
groupList := c.GetPeerGroups(peerID)
func (c *NetworkMapComponents) getPeerNSGroupsFromGroups(peerID string, groupList map[string]struct{}) []*nbdns.NameServerGroup {
var peerNSGroups []*nbdns.NameServerGroup
targetPeerInfo := c.GetPeerInfo(peerID)
if targetPeerInfo == nil {
return peerNSGroups
}
peerIPStr := targetPeerInfo.IP.String()
for _, nsGroup := range c.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
targetPeerInfo := c.GetPeerInfo(peerID)
if targetPeerInfo != nil && !c.peerIsNameserver(targetPeerInfo, nsGroup) {
if _, found := groupList[gID]; found {
if !c.peerIsNameserver(peerIPStr, nsGroup) {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
break
}
}
}
@@ -446,9 +447,9 @@ func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServe
return peerNSGroups
}
func (c *NetworkMapComponents) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
func (c *NetworkMapComponents) peerIsNameserver(peerIPStr string, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers {
if peerInfo.IP.String() == ns.IP.String() {
if peerIPStr == ns.IP.String() {
return true
}
}
@@ -489,14 +490,13 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
}
seenRoute[r.ID] = struct{}{}
routeObj := c.copyRoute(r)
routeObj.Peer = peerInfo.Key
r.Peer = peerInfo.Key
if r.Enabled {
enabledRoutes = append(enabledRoutes, routeObj)
enabledRoutes = append(enabledRoutes, r)
return
}
disabledRoutes = append(disabledRoutes, routeObj)
disabledRoutes = append(disabledRoutes, r)
}
for _, r := range c.Routes {
@@ -510,7 +510,7 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
continue
}
newPeerRoute := c.copyRoute(r)
newPeerRoute := r.Copy()
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id)
@@ -519,50 +519,13 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
}
}
if r.Peer == peerID {
takeRoute(c.copyRoute(r))
takeRoute(r.Copy())
}
}
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) copyRoute(r *route.Route) *route.Route {
var groups, accessControlGroups, peerGroups []string
var domains domain.List
if r.Groups != nil {
groups = append([]string{}, r.Groups...)
}
if r.AccessControlGroups != nil {
accessControlGroups = append([]string{}, r.AccessControlGroups...)
}
if r.PeerGroups != nil {
peerGroups = append([]string{}, r.PeerGroups...)
}
if r.Domains != nil {
domains = append(domain.List{}, r.Domains...)
}
return &route.Route{
ID: r.ID,
AccountID: r.AccountID,
Network: r.Network,
NetworkType: r.NetworkType,
Description: r.Description,
Peer: r.Peer,
PeerID: r.PeerID,
Metric: r.Metric,
Masquerade: r.Masquerade,
NetID: r.NetID,
Enabled: r.Enabled,
Groups: groups,
AccessControlGroups: accessControlGroups,
PeerGroups: peerGroups,
Domains: domains,
KeepRoute: r.KeepRoute,
SkipAutoApply: r.SkipAutoApply,
}
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route

View File

@@ -61,6 +61,10 @@ type Settings struct {
// AutoUpdateVersion client auto-update version
AutoUpdateVersion string `gorm:"default:'disabled'"`
// AutoUpdateAlways when true, updates are installed automatically in the background;
// when false, updates require user interaction from the UI
AutoUpdateAlways bool `gorm:"default:false"`
// EmbeddedIdpEnabled indicates if the embedded identity provider is enabled.
// This is a runtime-only field, not stored in the database.
EmbeddedIdpEnabled bool `gorm:"-"`
@@ -91,6 +95,7 @@ func (s *Settings) Copy() *Settings {
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,
AutoUpdateVersion: s.AutoUpdateVersion,
AutoUpdateAlways: s.AutoUpdateAlways,
EmbeddedIdpEnabled: s.EmbeddedIdpEnabled,
LocalAuthDisabled: s.LocalAuthDisabled,
}