diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index c24965e8d..205327ef5 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -11,7 +11,7 @@ import ( "go.opentelemetry.io/otel" "google.golang.org/grpc" - "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp t.Fatal(err) } - iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore) + iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 834a49a09..289f1906f 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -27,7 +27,7 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/management/server/job" - "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" @@ -66,8 +66,8 @@ import ( "github.com/netbirdio/netbird/route" mgmt "github.com/netbirdio/netbird/shared/management/client" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" - relayClient "github.com/netbirdio/netbird/shared/relay/client" "github.com/netbirdio/netbird/shared/netiputil" + relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/shared/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" @@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore) + ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/client/server/server_test.go b/client/server/server_test.go index 641cd85fe..66e0fcc4c 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" @@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore) + ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go index 1678c3996..47e30b6e6 100644 --- a/client/wasm/internal/rdp/cert_validation.go +++ b/client/wasm/internal/rdp/cert_validation.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "sync" "syscall/js" "time" @@ -13,7 +14,7 @@ import ( ) const ( - certValidationTimeout = 60 * time.Second + certValidationTimeout = 5 * time.Minute ) func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) { @@ -46,17 +47,31 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert promise := conn.wsHandlers.Call("onCertificateRequest", certInfo) - resultChan := make(chan bool) - errorChan := make(chan error) + resultChan := make(chan bool, 1) + errorChan := make(chan error, 1) - promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} { - result := args[0].Bool() - resultChan <- result + // Release from inside the callbacks so a post-timeout promise resolution + // does not invoke an already-released func. + var thenFn, catchFn js.Func + var releaseOnce sync.Once + release := func() { + releaseOnce.Do(func() { + thenFn.Release() + catchFn.Release() + }) + } + thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + defer release() + resultChan <- args[0].Bool() return nil - })).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + }) + catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + defer release() errorChan <- fmt.Errorf("certificate validation failed") return nil - })) + }) + + promise.Call("then", thenFn).Call("catch", catchFn) select { case result := <-resultChan: diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 6c36fdec6..ee420dca4 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -11,6 +11,7 @@ import ( "io" "net" "sync" + "sync/atomic" "syscall/js" "time" @@ -57,6 +58,8 @@ type RDCleanPathProxy struct { } activeConnections map[string]*proxyConnection destinations map[string]string + pendingHandlers map[string]js.Func + nextID atomic.Uint64 mu sync.Mutex } @@ -66,8 +69,15 @@ type proxyConnection struct { rdpConn net.Conn tlsConn *tls.Conn wsHandlers js.Value - ctx context.Context - cancel context.CancelFunc + // Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a + // global handle map and MUST be released, otherwise every connection + // leaks the Go memory the closure captures. + wsHandlerFn js.Func + onMessageFn js.Func + onCloseFn js.Func + cleanupOnce sync.Once + ctx context.Context + cancel context.CancelFunc } // NewRDCleanPathProxy creates a new RDCleanPath proxy @@ -80,7 +90,11 @@ func NewRDCleanPathProxy(client interface { } } -// CreateProxy creates a new proxy endpoint for the given destination +// CreateProxy creates a new proxy endpoint for the given destination. +// The registered handler fn and its destinations/pendingHandlers entries are +// only released once a connection is established and cleanupConnection runs. +// If a caller invokes CreateProxy but never connects to the returned URL, +// those entries stay pinned for the lifetime of the page. func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { destination := net.JoinHostPort(hostname, port) @@ -88,7 +102,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { resolve := args[0] go func() { - proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections)) + proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1)) p.mu.Lock() if p.destinations == nil { @@ -100,7 +114,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID) // Register the WebSocket handler for this specific proxy - js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any { + handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any { if len(args) < 1 { return js.ValueOf("error: requires WebSocket argument") } @@ -108,7 +122,14 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { ws := args[0] p.HandleWebSocketConnection(ws, proxyID) return nil - })) + }) + p.mu.Lock() + if p.pendingHandlers == nil { + p.pendingHandlers = make(map[string]js.Func) + } + p.pendingHandlers[proxyID] = handlerFn + p.mu.Unlock() + js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn) log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination) resolve.Invoke(proxyURL) @@ -142,6 +163,10 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string p.mu.Lock() p.activeConnections[proxyID] = conn + if fn, ok := p.pendingHandlers[proxyID]; ok { + conn.wsHandlerFn = fn + delete(p.pendingHandlers, proxyID) + } p.mu.Unlock() p.setupWebSocketHandlers(ws, conn) @@ -150,7 +175,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string } func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) { - ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any { + conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any { if len(args) < 1 { return nil } @@ -158,13 +183,15 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec data := args[0] go p.handleWebSocketMessage(conn, data) return nil - })) + }) + ws.Set("onGoMessage", conn.onMessageFn) - ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any { + conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any { log.Debug("WebSocket closed by JavaScript") conn.cancel() return nil - })) + }) + ws.Set("onGoClose", conn.onCloseFn) } func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) { @@ -261,25 +288,49 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] } func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) { - log.Debugf("Cleaning up connection %s", conn.id) - conn.cancel() - if conn.tlsConn != nil { - log.Debug("Closing TLS connection") - if err := conn.tlsConn.Close(); err != nil { - log.Debugf("Error closing TLS connection: %v", err) + conn.cleanupOnce.Do(func() { + log.Debugf("Cleaning up connection %s", conn.id) + conn.cancel() + if conn.tlsConn != nil { + log.Debug("Closing TLS connection") + if err := conn.tlsConn.Close(); err != nil { + log.Debugf("Error closing TLS connection: %v", err) + } + conn.tlsConn = nil } - conn.tlsConn = nil - } - if conn.rdpConn != nil { - log.Debug("Closing TCP connection") - if err := conn.rdpConn.Close(); err != nil { - log.Debugf("Error closing TCP connection: %v", err) + if conn.rdpConn != nil { + log.Debug("Closing TCP connection") + if err := conn.rdpConn.Close(); err != nil { + log.Debugf("Error closing TCP connection: %v", err) + } + conn.rdpConn = nil } - conn.rdpConn = nil - } - p.mu.Lock() - delete(p.activeConnections, conn.id) - p.mu.Unlock() + js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id)) + + // Detach before releasing so late JS calls surface as TypeError instead + // of silent "call to released function". + if conn.wsHandlers.Truthy() { + conn.wsHandlers.Set("onGoMessage", js.Undefined()) + conn.wsHandlers.Set("onGoClose", js.Undefined()) + } + + // wsHandlerFn may be zero-value if the pending handler lookup missed. + if conn.wsHandlerFn.Truthy() { + conn.wsHandlerFn.Release() + } + if conn.onMessageFn.Truthy() { + conn.onMessageFn.Release() + } + if conn.onCloseFn.Truthy() { + conn.onCloseFn.Release() + } + + p.mu.Lock() + delete(p.activeConnections, conn.id) + delete(p.destinations, conn.id) + delete(p.pendingHandlers, conn.id) + p.mu.Unlock() + }) } func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { diff --git a/client/wasm/internal/ssh/handlers.go b/client/wasm/internal/ssh/handlers.go index ea64eb0aa..6d33916a5 100644 --- a/client/wasm/internal/ssh/handlers.go +++ b/client/wasm/internal/ssh/handlers.go @@ -13,7 +13,7 @@ import ( func CreateJSInterface(client *Client) js.Value { jsInterface := js.Global().Get("Object").Call("create", js.Null()) - jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any { + writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any { if len(args) < 1 { return js.ValueOf(false) } @@ -32,9 +32,10 @@ func CreateJSInterface(client *Client) js.Value { _, err := client.Write(bytes) return js.ValueOf(err == nil) - })) + }) + jsInterface.Set("write", writeFunc) - jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any { + resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any { if len(args) < 2 { return js.ValueOf(false) } @@ -42,14 +43,26 @@ func CreateJSInterface(client *Client) js.Value { rows := args[1].Int() err := client.Resize(cols, rows) return js.ValueOf(err == nil) - })) + }) + jsInterface.Set("resize", resizeFunc) - jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any { + closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any { client.Close() return js.Undefined() - })) + }) + jsInterface.Set("close", closeFunc) - go readLoop(client, jsInterface) + go func() { + readLoop(client, jsInterface) + // Detach before releasing so late JS calls surface as TypeError instead + // of silent "call to released function". + jsInterface.Set("write", js.Undefined()) + jsInterface.Set("resize", js.Undefined()) + jsInterface.Set("close", js.Undefined()) + writeFunc.Release() + resizeFunc.Release() + closeFunc.Release() + }() return jsInterface } diff --git a/combined/cmd/root.go b/combined/cmd/root.go index db986b4d4..78290388b 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) { log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress) } - s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg)) + s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg)) if servers.relaySrv != nil { log.Infof("Relay WebSocket handler added (path: /relay)") } @@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* } // createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic -func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { +func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) var relayAcceptFn func(conn listener.Conn) @@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re http.Error(w, "Relay service not enabled", http.StatusNotFound) } + // Embedded IdP (Dex) + case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"): + idpHandler.ServeHTTP(w, r) + // Management HTTP API (default) default: httpHandler.ServeHTTP(w, r) diff --git a/management/internals/controllers/network_map/update_channel/updatechannel.go b/management/internals/controllers/network_map/update_channel/updatechannel.go index 5f7db5300..91627bf15 100644 --- a/management/internals/controllers/network_map/update_channel/updatechannel.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel.go @@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda found = true select { case channel <- update: - log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID) + log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID) default: dropped = true log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel)) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 7c655f020..46e475143 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -10,8 +10,10 @@ import ( "slices" "time" + "github.com/gorilla/mux" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "github.com/rs/cors" "github.com/rs/xid" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -19,7 +21,6 @@ import ( "google.golang.org/grpc/keepalive" cachestore "github.com/eko/gocache/lib/v4/store" - "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" @@ -27,16 +28,20 @@ import ( accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + activitystore "github.com/netbirdio/netbird/management/server/activity/store" nbcache "github.com/netbirdio/netbird/management/server/cache" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util/crypt" ) +const apiPrefix = "/api" + var ( kaep = keepalive.EnforcementPolicy{ MinTime: 15 * time.Second, @@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store { func (s *BaseServer) EventStore() activity.Store { return Create(s, func() activity.Store { - integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics()) - if err != nil { - log.Fatalf("failed to initialize integration metrics: %v", err) + var err error + key := s.Config.DataStoreEncryptionKey + if key == "" { + log.Debugf("generate new activity store encryption key") + key, err = crypt.GenerateKey() + if err != nil { + log.Fatalf("failed to generate event store encryption key: %v", err) + } } - eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics) + eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key) if err != nil { log.Fatalf("failed to initialize event store: %v", err) } @@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler { }) } +// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if +// the deployment isn't using the embedded variant. +func (s *BaseServer) IDPHandler() http.Handler { + embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager) + if !ok || embeddedIdP == nil { + return nil + } + return cors.AllowAll().Handler(embeddedIdP.Handler()) +} + +func (s *BaseServer) Router() *mux.Router { + return Create(s, func() *mux.Router { + return mux.NewRouter().PathPrefix(apiPrefix).Subrouter() + }) +} + func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter { return Create(s, func() *middleware.APIRateLimiter { cfg, enabled := middleware.RateLimiterConfigFromEnv() diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 794c3ebe0..1b2556809 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" @@ -38,7 +39,7 @@ func (s *BaseServer) JobManager() *job.Manager { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator { - integratedPeerValidator, err := integrations.NewIntegratedValidator( + integratedPeerValidator, err := validator.NewIntegratedValidator( context.Background(), s.PeersManager(), s.SettingsManager(), diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index ea94245d5..a70da855a 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) - - s.AfterInit(func(s *BaseServer) { - manager.SetAccountManager(s.AccountManager()) - }) - - return manager + return permissions.NewManager(s.Store()) }) } @@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager { return idpManager } - return nil }) } @@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return &m }) } + +func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool { + return false +} diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 9b8716da1..63d13baab 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } - rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter()) + rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter()) switch { case s.certManager != nil: // a call to certManager.Listener() always creates a new listener so we do it once @@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) { log.Tracef("custom handler set successfully") } -func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { +func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler { // Check if a custom handler was set (for multiplexing additional services) if customHandler, ok := s.GetContainer("customHandler"); ok { if handler, ok := customHandler.(http.Handler); ok { @@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht gRPCHandler.ServeHTTP(writer, request) case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent: wsProxy.Handler().ServeHTTP(writer, request) + case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"): + idpHandler.ServeHTTP(writer, request) default: httpHandler.ServeHTTP(writer, request) } diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 1d8234304..d36e72045 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg return nil } - log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) + log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String()) if debouncer.ProcessUpdate(update) { // Send immediately (first update or after quiet period) if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { @@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed sending update message") } - log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) + log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String()) return nil } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 1e2c710db..0abdb854d 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -15,15 +15,13 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" idpmanager "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/management-integrations/integrations" - "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/modules/zones" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" @@ -32,12 +30,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/http/handlers/proxy" - nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" nbgroups "github.com/netbirdio/netbird/management/server/groups" @@ -56,17 +52,14 @@ import ( "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" nbinstance "github.com/netbirdio/netbird/management/server/instance" - "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" nbnetworks "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/telemetry" ) -const apiPrefix = "/api" - // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -100,25 +93,16 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks accountManager.GetUserFromUserAuth, rateLimiter, appMetrics.GetMeter(), + isValidChildAccount, ) corsMiddleware := cors.AllowAll() - rootRouter := mux.NewRouter() metricsMiddleware := appMetrics.HTTPMiddleware() - prefix := apiPrefix - router := rootRouter.PathPrefix(prefix).Subrouter() - router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler) - if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil { - return nil, fmt.Errorf("register integrations endpoints: %w", err) - } - - // Check if embedded IdP is enabled for instance manager - embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager) - instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP) + instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager) if err != nil { return nil, fmt.Errorf("failed to create instance manager: %w", err) } @@ -154,10 +138,5 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks oauthHandler.RegisterEndpoints(router) } - // Mount embedded IdP handler at /oauth2 path if configured - if embeddedIdpEnabled { - rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler())) - } - - return rootRouter, nil + return router, nil } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6d075d9c2..34df0de23 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -11,8 +11,6 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" - "github.com/netbirdio/management-integrations/integrations" - serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" @@ -27,6 +25,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) +type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool + // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { authManager serverauth.Manager @@ -35,6 +35,7 @@ type AuthMiddleware struct { syncUserJWTGroups SyncUserJWTGroupsFunc rateLimiter *APIRateLimiter patUsageTracker *PATUsageTracker + isValidChildAccount IsValidChildAccountFunc } // NewAuthMiddleware instance constructor @@ -45,6 +46,7 @@ func NewAuthMiddleware( getUserFromUserAuth GetUserFromUserAuthFunc, rateLimiter *APIRateLimiter, meter metric.Meter, + isValidChildAccount IsValidChildAccountFunc, ) *AuthMiddleware { var patUsageTracker *PATUsageTracker if meter != nil { @@ -62,6 +64,7 @@ func NewAuthMiddleware( getUserFromUserAuth: getUserFromUserAuth, rateLimiter: rateLimiter, patUsageTracker: patUsageTracker, + isValidChildAccount: isValidChildAccount, } } @@ -124,7 +127,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] } if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { - if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) { + if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) { userAuth.AccountId = impersonate[0] userAuth.IsChild = true } @@ -203,7 +206,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] } if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { - if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) { + if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) { userAuth.AccountId = impersonate[0] userAuth.IsChild = true } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 8f736fbfd..24cf8fce5 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { }, disabledLimiter, nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { }, disabledLimiter, nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) for _, tc := range tt { diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 3c4ea98d0..8da9c7ad4 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,6 +7,7 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel/metric/noop" @@ -135,7 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) + apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter() + apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -264,7 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) + apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter() + apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/integrations/integrated_validator/validator/validator.go b/management/server/integrations/integrated_validator/validator/validator.go new file mode 100644 index 000000000..db1d34373 --- /dev/null +++ b/management/server/integrations/integrated_validator/validator/validator.go @@ -0,0 +1,62 @@ +package validator + +import ( + "context" + + cachestore "github.com/eko/gocache/lib/v4/store" + + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/server/activity" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +type IntegratedValidatorImpl struct{} + +func NewIntegratedValidator(_ context.Context, _ peers.Manager, _ settings.Manager, _ activity.Store, _ cachestore.StoreInterface) (*IntegratedValidatorImpl, error) { + return &IntegratedValidatorImpl{}, nil +} + +func (v *IntegratedValidatorImpl) ValidateExtraSettings(context.Context, *types.ExtraSettings, *types.ExtraSettings, string, string) error { + return nil +} + +func (v *IntegratedValidatorImpl) ValidatePeer(_ context.Context, update *nbpeer.Peer, _ *nbpeer.Peer, _ string, _ string, _ string, _ []string, _ *types.ExtraSettings) (*nbpeer.Peer, bool, error) { + return update, false, nil +} + +func (v *IntegratedValidatorImpl) PreparePeer(_ context.Context, _ string, peer *nbpeer.Peer, _ []string, _ *types.ExtraSettings, _ bool) *nbpeer.Peer { + return peer.Copy() +} + +func (v *IntegratedValidatorImpl) IsNotValidPeer(_ context.Context, _ string, _ *nbpeer.Peer, _ []string, _ *types.ExtraSettings) (bool, bool, error) { + return false, false, nil +} + +func (v *IntegratedValidatorImpl) GetValidatedPeers(_ context.Context, _ string, _ []*types.Group, peers []*nbpeer.Peer, _ *types.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, p := range peers { + validatedPeers[p.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (v *IntegratedValidatorImpl) GetInvalidPeers(_ context.Context, _ string, _ *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + +func (v *IntegratedValidatorImpl) PeerDeleted(_ context.Context, _, _ string, _ *types.ExtraSettings) error { + return nil +} + +func (v *IntegratedValidatorImpl) SetPeerInvalidationListener(_ func(accountID string, peerIDs []string)) { +} + +func (v *IntegratedValidatorImpl) Stop(_ context.Context) { +} + +func (v *IntegratedValidatorImpl) ValidateFlowResponse(_ context.Context, _ string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow { + return flowResponse +} diff --git a/management/server/posture/nb_version.go b/management/server/posture/nb_version.go index 33bf01ad1..6e4757021 100644 --- a/management/server/posture/nb_version.go +++ b/management/server/posture/nb_version.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/hashicorp/go-version" - log "github.com/sirupsen/logrus" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -33,9 +32,6 @@ func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, err return true, nil } - log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s", - peer.ID, peer.Meta.WtVersion, n.MinVersion) - return false, nil } diff --git a/management/server/posture/os_version.go b/management/server/posture/os_version.go index 2ef97a066..724bebfcc 100644 --- a/management/server/posture/os_version.go +++ b/management/server/posture/os_version.go @@ -100,8 +100,6 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M return true, nil } - log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion) - return false, nil } @@ -125,7 +123,5 @@ func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, ch return true, nil } - log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion) - return false, nil } diff --git a/management/server/user.go b/management/server/user.go index 892d982e7..60571a702 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -762,7 +762,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } // Ensure the initiator still has admin privileges - if initiatorUser.HasAdminPower() && !freshInitiator.HasAdminPower() { + if !freshInitiator.HasAdminPower() { return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing") } initiatorUser = freshInitiator @@ -906,19 +906,23 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse return nil } + if !initiatorUser.HasAdminPower() { + return status.Errorf(status.PermissionDenied, "only admins and owners can update users") + } + if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role { return status.Errorf(status.PermissionDenied, "admins can't change their role") } - if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role != types.UserRoleOwner && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user") } - if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { + if oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { return status.Errorf(status.PermissionDenied, "unable to block owner user") } - if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role != types.UserRoleOwner && update.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users") } if oldUser.IsServiceUser && update.Role == types.UserRoleOwner { diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index a8e8172dc..be2c009ad 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -17,7 +17,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -103,7 +103,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } - ia, _ := integrations.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore) + ia, _ := validator.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err)