diff --git a/client/cmd/root.go b/client/cmd/root.go index aa5b98dfd..c872fe9f6 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -75,6 +75,7 @@ var ( mtu uint16 profilesDisabled bool updateSettingsDisabled bool + networksDisabled bool rootCmd = &cobra.Command{ Use: "netbird", diff --git a/client/cmd/service.go b/client/cmd/service.go index 5ff16eaeb..f1123ce8c 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -44,10 +44,13 @@ func init() { serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd) serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") + serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") serviceEnvDesc := `Sets extra environment variables for the service. ` + `You can specify a comma-separated list of KEY=VALUE pairs. ` + + `New keys are merged with previously saved env vars; existing keys are overwritten. ` + + `Use --service-env "" to clear all saved env vars. ` + `E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value` installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 5fe318ddf..0943b6184 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled) + serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 28770ea16..5ada6f633 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -59,6 +59,10 @@ func buildServiceArguments() []string { args = append(args, "--disable-update-settings") } + if networksDisabled { + args = append(args, "--disable-networks") + } + return args } diff --git a/client/cmd/service_params.go b/client/cmd/service_params.go index 81bd2dbb5..5a86aebc6 100644 --- a/client/cmd/service_params.go +++ b/client/cmd/service_params.go @@ -28,6 +28,7 @@ type serviceParams struct { LogFiles []string `json:"log_files,omitempty"` DisableProfiles bool `json:"disable_profiles,omitempty"` DisableUpdateSettings bool `json:"disable_update_settings,omitempty"` + DisableNetworks bool `json:"disable_networks,omitempty"` ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"` } @@ -78,11 +79,12 @@ func currentServiceParams() *serviceParams { LogFiles: logFiles, DisableProfiles: profilesDisabled, DisableUpdateSettings: updateSettingsDisabled, + DisableNetworks: networksDisabled, } if len(serviceEnvVars) > 0 { parsed, err := parseServiceEnvVars(serviceEnvVars) - if err == nil && len(parsed) > 0 { + if err == nil { params.ServiceEnvVars = parsed } } @@ -142,31 +144,46 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) { updateSettingsDisabled = params.DisableUpdateSettings } + if !serviceCmd.PersistentFlags().Changed("disable-networks") { + networksDisabled = params.DisableNetworks + } + applyServiceEnvParams(cmd, params) } // applyServiceEnvParams merges saved service environment variables. -// If --service-env was explicitly set, explicit values win on key conflict -// but saved keys not in the explicit set are carried over. +// If --service-env was explicitly set with values, explicit values win on key +// conflict but saved keys not in the explicit set are carried over. +// If --service-env was explicitly set to empty, all saved env vars are cleared. // If --service-env was not set, saved env vars are used entirely. func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) { - if len(params.ServiceEnvVars) == 0 { - return - } - if !cmd.Flags().Changed("service-env") { - // No explicit env vars: rebuild serviceEnvVars from saved params. - serviceEnvVars = envMapToSlice(params.ServiceEnvVars) + if len(params.ServiceEnvVars) > 0 { + // No explicit env vars: rebuild serviceEnvVars from saved params. + serviceEnvVars = envMapToSlice(params.ServiceEnvVars) + } return } - // Explicit env vars were provided: merge saved values underneath. + // Flag was explicitly set: parse what the user provided. explicit, err := parseServiceEnvVars(serviceEnvVars) if err != nil { cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err) return } + // If the user passed an empty value (e.g. --service-env ""), clear all + // saved env vars rather than merging. + if len(explicit) == 0 { + serviceEnvVars = nil + return + } + + if len(params.ServiceEnvVars) == 0 { + return + } + + // Merge saved values underneath explicit ones. merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit)) maps.Copy(merged, params.ServiceEnvVars) maps.Copy(merged, explicit) // explicit wins on conflict diff --git a/client/cmd/service_params_test.go b/client/cmd/service_params_test.go index 3bc8e4f60..7e04e5abe 100644 --- a/client/cmd/service_params_test.go +++ b/client/cmd/service_params_test.go @@ -327,6 +327,41 @@ func TestApplyServiceEnvParams_NotChanged(t *testing.T) { assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result) } +func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Simulate --service-env "" which produces [""] in the slice. + serviceEnvVars = []string{""} + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + require.NoError(t, cmd.Flags().Set("service-env", "")) + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"}, + } + + applyServiceEnvParams(cmd, saved) + + assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars") +} + +func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Simulate --service-env "" which produces [""] in the slice. + serviceEnvVars = []string{""} + + params := currentServiceParams() + + // After parsing, the empty string is skipped, resulting in an empty map. + // The map should still be set (not nil) so it overwrites saved values. + assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil") + assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string") +} + // TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are // referenced in both currentServiceParams() and applyServiceParams(). If a new field is // added to serviceParams but not wired into these functions, this test fails. @@ -500,6 +535,7 @@ func fieldToGlobalVar(field string) string { "LogFiles": "logFiles", "DisableProfiles": "profilesDisabled", "DisableUpdateSettings": "updateSettingsDisabled", + "DisableNetworks": "networksDisabled", "ServiceEnvVars": "serviceEnvVars", } if v, ok := m[field]; ok { diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 4bda33e65..d7564c353 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -13,6 +13,8 @@ import ( "github.com/netbirdio/management-integrations/integrations" + nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" @@ -100,9 +102,16 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp jobManager := job.NewJobManager(nil, store, peersmanager) - iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) + ctx := context.Background() - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatal(err) + } + + iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore) + + metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) settingsMockManager := settings.NewMockManager(ctrl) @@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp Return(&types.Settings{}, nil). AnyTimes() - ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config) - accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { t.Fatal(err) } @@ -152,7 +160,7 @@ func startClientDaemon( s := grpc.NewServer() server := client.New(ctx, - "", "", false, false) + "", "", false, false, false) if err := server.Start(); err != nil { t.Fatal(err) } diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d83798f09..e629f7881 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -21,6 +21,10 @@ const ( // rules chains contains the effective ACL rules chainNameInputRules = "NETBIRD-ACL-INPUT" + + // mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent + // external DNAT from bypassing ACL rules. + mangleFwdKey = "MANGLE-FORWARD" ) type aclEntries map[string][][]string @@ -274,6 +278,12 @@ func (m *aclManager) cleanChains() error { } } + for _, rule := range m.entries[mangleFwdKey] { + if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil { + log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err) + } + } + for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := m.flushIPSet(ipsetName); err != nil { if errors.Is(err, ipset.ErrSetNotExist) { @@ -303,6 +313,10 @@ func (m *aclManager) createDefaultChains() error { } for chainName, rules := range m.entries { + // mangle FORWARD guard rules are handled separately below + if chainName == mangleFwdKey { + continue + } for _, rule := range rules { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { log.Debugf("failed to create input chain jump rule: %s", err) @@ -322,6 +336,13 @@ func (m *aclManager) createDefaultChains() error { } clear(m.optionalEntries) + // Insert mangle FORWARD guard rules to prevent external DNAT bypass. + for _, rule := range m.entries[mangleFwdKey] { + if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil { + log.Errorf("failed to add mangle FORWARD guard rule: %v", err) + } + } + return nil } @@ -343,6 +364,22 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN}) + + // Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it + // traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD + // can be inserted above ours. Mangle runs before filter, so these guard rules enforce the + // ACL mark check where it cannot be overridden. + m.appendToEntries(mangleFwdKey, []string{ + "-i", m.wgIface.Name(), + "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", + "-j", "ACCEPT", + }) + m.appendToEntries(mangleFwdKey, []string{ + "-i", m.wgIface.Name(), + "-m", "conntrack", "--ctstate", "DNAT", + "-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), + "-j", "DROP", + }) } func (m *aclManager) seedInitialOptionalEntries() { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 1f6fe384a..9fa4e51b2 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -55,6 +55,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -1634,7 +1635,12 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri peersManager := peers.NewManager(store, permissionsManager) jobManager := job.NewJobManager(nil, store, peersManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore) + cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, "", err + } + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -1656,7 +1662,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) - accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, "", err } diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index a4ffa3a25..2420b1fdf 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -7,7 +7,9 @@ import ( "fmt" "net/netip" "sync" + "time" + "github.com/cenkalti/backoff/v4" "github.com/google/uuid" log "github.com/sirupsen/logrus" nfct "github.com/ti-mo/conntrack" @@ -17,31 +19,64 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -const defaultChannelSize = 100 +const ( + defaultChannelSize = 100 + reconnectInitInterval = 5 * time.Second + reconnectMaxInterval = 5 * time.Minute + reconnectRandomization = 0.5 +) + +// listener abstracts a netlink conntrack connection for testability. +type listener interface { + Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error) + Close() error +} // ConnTrack manages kernel-based conntrack events type ConnTrack struct { flowLogger nftypes.FlowLogger iface nftypes.IFaceMapper - conn *nfct.Conn + conn listener mux sync.Mutex + dial func() (listener, error) instanceID uuid.UUID started bool done chan struct{} sysctlModified bool } +// DialFunc is a constructor for netlink conntrack connections. +type DialFunc func() (listener, error) + +// Option configures a ConnTrack instance. +type Option func(*ConnTrack) + +// WithDialer overrides the default netlink dialer, primarily for testing. +func WithDialer(dial DialFunc) Option { + return func(c *ConnTrack) { + c.dial = dial + } +} + +func defaultDial() (listener, error) { + return nfct.Dial(nil) +} + // New creates a new connection tracker that interfaces with the kernel's conntrack system -func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack { - return &ConnTrack{ +func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack { + ct := &ConnTrack{ flowLogger: flowLogger, iface: iface, instanceID: uuid.New(), - started: false, + dial: defaultDial, done: make(chan struct{}, 1), } + for _, opt := range opts { + opt(ct) + } + return ct } // Start begins tracking connections by listening for conntrack events. This method is idempotent. @@ -59,8 +94,9 @@ func (c *ConnTrack) Start(enableCounters bool) error { c.EnableAccounting() } - conn, err := nfct.Dial(nil) + conn, err := c.dial() if err != nil { + c.RestoreAccounting() return fmt.Errorf("dial conntrack: %w", err) } c.conn = conn @@ -76,9 +112,16 @@ func (c *ConnTrack) Start(enableCounters bool) error { log.Errorf("Error closing conntrack connection: %v", err) } c.conn = nil + c.RestoreAccounting() return fmt.Errorf("start conntrack listener: %w", err) } + // Drain any stale stop signal from a previous cycle. + select { + case <-c.done: + default: + } + c.started = true go c.receiverRoutine(events, errChan) @@ -92,17 +135,98 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) case event := <-events: c.handleEvent(event) case err := <-errChan: - log.Errorf("Error from conntrack event listener: %v", err) - if err := c.conn.Close(); err != nil { - log.Errorf("Error closing conntrack connection: %v", err) + if events, errChan = c.handleListenerError(err); events == nil { + return } - return case <-c.done: return } } } +// handleListenerError closes the failed connection and attempts to reconnect. +// Returns new channels on success, or nil if shutdown was requested. +func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) { + log.Warnf("conntrack event listener failed: %v", err) + c.closeConn() + return c.reconnect() +} + +func (c *ConnTrack) closeConn() { + c.mux.Lock() + defer c.mux.Unlock() + + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Debugf("close conntrack connection: %v", err) + } + c.conn = nil + } +} + +// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff. +// Returns new channels on success, or nil if shutdown was requested. +func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) { + bo := &backoff.ExponentialBackOff{ + InitialInterval: reconnectInitInterval, + RandomizationFactor: reconnectRandomization, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: reconnectMaxInterval, + MaxElapsedTime: 0, // retry indefinitely + Clock: backoff.SystemClock, + } + bo.Reset() + + for { + delay := bo.NextBackOff() + log.Infof("reconnecting conntrack listener in %s", delay) + + select { + case <-c.done: + c.mux.Lock() + c.started = false + c.mux.Unlock() + return nil, nil + case <-time.After(delay): + } + + conn, err := c.dial() + if err != nil { + log.Warnf("reconnect conntrack dial: %v", err) + continue + } + + events := make(chan nfct.Event, defaultChannelSize) + errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{ + netfilter.GroupCTNew, + netfilter.GroupCTDestroy, + }) + if err != nil { + log.Warnf("reconnect conntrack listen: %v", err) + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("close conntrack connection: %v", closeErr) + } + continue + } + + c.mux.Lock() + if !c.started { + // Stop() ran while we were reconnecting. + c.mux.Unlock() + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("close conntrack connection: %v", closeErr) + } + return nil, nil + } + c.conn = conn + c.mux.Unlock() + + log.Infof("conntrack listener reconnected successfully") + + return events, errChan + } +} + // Stop stops the connection tracking. This method is idempotent. func (c *ConnTrack) Stop() { c.mux.Lock() @@ -136,23 +260,27 @@ func (c *ConnTrack) Close() error { c.mux.Lock() defer c.mux.Unlock() - if c.started { - select { - case c.done <- struct{}{}: - default: - } + if !c.started { + return nil } + select { + case c.done <- struct{}{}: + default: + } + + c.started = false + + var closeErr error if c.conn != nil { - err := c.conn.Close() + closeErr = c.conn.Close() c.conn = nil - c.started = false + } - c.RestoreAccounting() + c.RestoreAccounting() - if err != nil { - return fmt.Errorf("close conntrack: %w", err) - } + if closeErr != nil { + return fmt.Errorf("close conntrack: %w", closeErr) } return nil diff --git a/client/internal/netflow/conntrack/conntrack_test.go b/client/internal/netflow/conntrack/conntrack_test.go new file mode 100644 index 000000000..35ceec90d --- /dev/null +++ b/client/internal/netflow/conntrack/conntrack_test.go @@ -0,0 +1,224 @@ +//go:build linux && !android + +package conntrack + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + nfct "github.com/ti-mo/conntrack" + "github.com/ti-mo/netfilter" +) + +type mockListener struct { + errChan chan error + closed atomic.Bool + closedCh chan struct{} +} + +func newMockListener() *mockListener { + return &mockListener{ + errChan: make(chan error, 1), + closedCh: make(chan struct{}), + } +} + +func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) { + return m.errChan, nil +} + +func (m *mockListener) Close() error { + if m.closed.CompareAndSwap(false, true) { + close(m.closedCh) + } + return nil +} + +func TestReconnectAfterError(t *testing.T) { + first := newMockListener() + second := newMockListener() + third := newMockListener() + listeners := []*mockListener{first, second, third} + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := int(callCount.Add(1)) - 1 + return listeners[n], nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Inject an error on the first listener. + first.errChan <- assert.AnError + + // Wait for reconnect to complete. + require.Eventually(t, func() bool { + return callCount.Load() >= 2 + }, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection") + + // The first connection must have been closed. + select { + case <-first.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("first connection was not closed") + } + + // Verify the receiver is still running by injecting and handling a second error. + second.errChan <- assert.AnError + + require.Eventually(t, func() bool { + return callCount.Load() >= 3 + }, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed") + + ct.Stop() +} + +func TestStopDuringReconnectBackoff(t *testing.T) { + mock := newMockListener() + + ct := New(nil, nil, WithDialer(func() (listener, error) { + return mock, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger an error so the receiver enters reconnect. + mock.errChan <- assert.AnError + + // Wait for the error handler to close the old listener before calling Stop. + select { + case <-mock.closedCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for reconnect to start") + } + + // Stop while reconnecting. + ct.Stop() + + ct.mux.Lock() + assert.False(t, ct.started, "started should be false after Stop") + assert.Nil(t, ct.conn, "conn should be nil after Stop") + ct.mux.Unlock() +} + +func TestStopRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + // Second dial: signal that we're in progress, wait for test to call Stop. + close(dialStarted) + <-dialProceed + return second, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger error to enter reconnect. + first.errChan <- assert.AnError + + // Wait for reconnect's second dial to begin. + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Stop while dial is in progress (conn is nil at this point). + ct.Stop() + + // Let the dial complete. reconnect should detect started==false and close the new conn. + close(dialProceed) + + // The second connection should be closed (not leaked). + select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Stop") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + +func TestCloseRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + close(dialStarted) + <-dialProceed + return second, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + first.errChan <- assert.AnError + + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Close while dial is in progress (conn is nil). + require.NoError(t, ct.Close()) + + close(dialProceed) + + // The second connection should be closed (not leaked). + select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Close") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + +func TestStartIsIdempotent(t *testing.T) { + mock := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + callCount.Add(1) + return mock, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Second Start should be a no-op. + err = ct.Start(false) + require.NoError(t, err) + + assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once") + + ct.Stop() +} diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index fa0b2f93b..6506307d3 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -4979,6 +4979,7 @@ type GetFeaturesResponse struct { state protoimpl.MessageState `protogen:"open.v1"` DisableProfiles bool `protobuf:"varint,1,opt,name=disable_profiles,json=disableProfiles,proto3" json:"disable_profiles,omitempty"` DisableUpdateSettings bool `protobuf:"varint,2,opt,name=disable_update_settings,json=disableUpdateSettings,proto3" json:"disable_update_settings,omitempty"` + DisableNetworks bool `protobuf:"varint,3,opt,name=disable_networks,json=disableNetworks,proto3" json:"disable_networks,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -5027,6 +5028,13 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool { return false } +func (x *GetFeaturesResponse) GetDisableNetworks() bool { + if x != nil { + return x.DisableNetworks + } + return false +} + type TriggerUpdateRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -6472,10 +6480,11 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\x10\n" + "\x0eLogoutResponse\"\x14\n" + - "\x12GetFeaturesRequest\"x\n" + + "\x12GetFeaturesRequest\"\xa3\x01\n" + "\x13GetFeaturesResponse\x12)\n" + "\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" + - "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"\x16\n" + + "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\x12)\n" + + "\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"\x16\n" + "\x14TriggerUpdateRequest\"M\n" + "\x15TriggerUpdateResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 89302c8c3..19976660c 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -727,6 +727,7 @@ message GetFeaturesRequest{} message GetFeaturesResponse{ bool disable_profiles = 1; bool disable_update_settings = 2; + bool disable_networks = 3; } message TriggerUpdateRequest {} diff --git a/client/server/network.go b/client/server/network.go index bb1cce56c..76c5af40e 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -9,6 +9,8 @@ import ( "strings" "golang.org/x/exp/maps" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/route" @@ -27,6 +29,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } @@ -118,6 +124,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } @@ -164,6 +174,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } diff --git a/client/server/server.go b/client/server/server.go index e12b6df5b..70e4c342f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -53,6 +53,7 @@ const ( errRestoreResidualState = "failed to restore residual state: %v" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled" + errNetworksDisabled = "network selection is disabled by the administrator" ) var ErrServiceNotUp = errors.New("service is not up") @@ -88,6 +89,7 @@ type Server struct { profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool + networksDisabled bool sleepHandler *sleephandler.SleepHandler @@ -104,7 +106,7 @@ type oauthAuthFlow struct { } // New server instance constructor. -func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server { +func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server { s := &Server{ rootCtx: ctx, logFile: logFile, @@ -113,6 +115,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profileManager: profilemanager.NewServiceManager(configFile), profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, + networksDisabled: networksDisabled, jwtCache: newJWTCache(), } agent := &serverAgent{s} @@ -1628,6 +1631,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) features := &proto.GetFeaturesResponse{ DisableProfiles: s.checkProfilesDisabled(), DisableUpdateSettings: s.checkUpdateSettingsDisabled(), + DisableNetworks: s.networksDisabled, } return features, nil diff --git a/client/server/server_test.go b/client/server/server_test.go index 6de23d501..772997575 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -36,6 +36,7 @@ import ( daemonProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -103,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "debug", "", false, false) + s := New(ctx, "debug", "", false, false, false) s.config = config @@ -164,7 +165,7 @@ func TestServer_Up(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false) err = s.Start() require.NoError(t, err) @@ -234,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false) err = s.Start() require.NoError(t, err) @@ -309,7 +310,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve jobManager := job.NewJobManager(nil, store, peersManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore) + cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, "", err + } + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -320,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) - accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { return nil, "", err } diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 8e360175d..7f6847c43 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.NoError(t, err) ctx := context.Background() - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false) rosenpassEnabled := true rosenpassPermissive := true diff --git a/client/system/info_ios.go b/client/system/info_ios.go index 322609db4..81936cf1d 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -4,10 +4,12 @@ import ( "context" "runtime" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/version" ) -// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +// UpdateStaticInfoAsync is a no-op on iOS as there is no static info to update func UpdateStaticInfoAsync() { // do nothing } @@ -15,11 +17,24 @@ func UpdateStaticInfoAsync() { // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { - // Convert fixed-size byte arrays to Go strings sysName := extractOsName(ctx, "sysName") swVersion := extractOsVersion(ctx, "swVersion") - gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion} + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } + + gio := &Info{ + Kernel: sysName, + OSVersion: swVersion, + Platform: "unknown", + OS: sysName, + GoOS: runtime.GOOS, + CPUs: runtime.NumCPU(), + KernelVersion: swVersion, + NetworkAddresses: addrs, + } gio.Hostname = extractDeviceName(ctx, "hostname") gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index b1e0aec41..c149b2152 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -314,6 +314,7 @@ type serviceClient struct { lastNotifiedVersion string settingsEnabled bool profilesEnabled bool + networksEnabled bool showNetworks bool wNetworks fyne.Window wProfiles fyne.Window @@ -368,6 +369,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, + networksEnabled: true, } s.eventHandler = newEventHandler(s) @@ -920,8 +922,10 @@ func (s *serviceClient) updateStatus() error { s.mStatus.SetIcon(s.icConnectedDot) s.mUp.Disable() s.mDown.Enable() - s.mNetworks.Enable() - s.mExitNode.Enable() + if s.networksEnabled { + s.mNetworks.Enable() + s.mExitNode.Enable() + } s.startExitNodeRefresh() systrayIconState = true case status.Status == string(internal.StatusConnecting): @@ -1093,14 +1097,14 @@ func (s *serviceClient) onTrayReady() { s.getSrvConfig() time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon for { + // Check features before status so menus respect disable flags before being enabled + s.checkAndUpdateFeatures() + err := s.updateStatus() if err != nil { log.Errorf("error while updating status: %v", err) } - // Check features periodically to handle daemon restarts - s.checkAndUpdateFeatures() - time.Sleep(2 * time.Second) } }() @@ -1299,6 +1303,16 @@ func (s *serviceClient) checkAndUpdateFeatures() { s.mProfile.setEnabled(profilesEnabled) } } + + // Update networks and exit node menus based on current features + s.networksEnabled = features == nil || !features.DisableNetworks + if s.networksEnabled && s.connected { + s.mNetworks.Enable() + s.mExitNode.Enable() + } else { + s.mNetworks.Disable() + s.mExitNode.Disable() + } } // getFeatures from the daemon to determine which features are enabled/disabled. diff --git a/go.mod b/go.mod index 127aa7cb1..32ff5f89f 100644 --- a/go.mod +++ b/go.mod @@ -71,7 +71,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 + github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 @@ -121,7 +121,7 @@ require ( golang.org/x/sync v0.20.0 golang.org/x/term v0.42.0 golang.org/x/time v0.15.0 - google.golang.org/api v0.275.0 + google.golang.org/api v0.276.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.7 diff --git a/go.sum b/go.sum index aa9f349b0..a925038cc 100644 --- a/go.sum +++ b/go.sum @@ -489,8 +489,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs= -github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= @@ -902,8 +902,8 @@ golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= -google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI= -google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw= +google.golang.org/api v0.276.0 h1:nVArUtfLEihtW+b0DdcqRGK1xoEm2+ltAihyztq7MKY= +google.golang.org/api v0.276.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0= google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I= diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 9236d851d..08da48264 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -182,6 +182,23 @@ read_enable_proxy() { return 0 } +read_enable_crowdsec() { + echo "" > /dev/stderr + echo "Do you want to enable CrowdSec IP reputation blocking?" > /dev/stderr + echo "CrowdSec checks client IPs against a community threat intelligence database" > /dev/stderr + echo "and blocks known malicious sources before they reach your services." > /dev/stderr + echo "A local CrowdSec LAPI container will be added to your deployment." > /dev/stderr + echo -n "Enable CrowdSec? [y/N]: " > /dev/stderr + read -r CHOICE < /dev/tty + + if [[ "$CHOICE" =~ ^[Yy]$ ]]; then + echo "true" + else + echo "false" + fi + return 0 +} + read_traefik_acme_email() { echo "" > /dev/stderr echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr @@ -297,6 +314,10 @@ initialize_default_values() { # NetBird Proxy configuration ENABLE_PROXY="false" PROXY_TOKEN="" + + # CrowdSec configuration + ENABLE_CROWDSEC="false" + CROWDSEC_BOUNCER_KEY="" return 0 } @@ -325,6 +346,9 @@ configure_reverse_proxy() { if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email) ENABLE_PROXY=$(read_enable_proxy) + if [[ "$ENABLE_PROXY" == "true" ]]; then + ENABLE_CROWDSEC=$(read_enable_crowdsec) + fi fi # Handle external Traefik-specific prompts (option 1) @@ -354,7 +378,7 @@ check_existing_installation() { echo "Generated files already exist, if you want to reinitialize the environment, please remove them first." echo "You can use the following commands:" echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes" - echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt" + echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt && rm -rf crowdsec/" echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard." exit 1 fi @@ -375,6 +399,9 @@ generate_configuration_files() { echo "NB_PROXY_TOKEN=placeholder" >> proxy.env # TCP ServersTransport for PROXY protocol v2 to the proxy backend render_traefik_dynamic > traefik-dynamic.yaml + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + mkdir -p crowdsec + fi fi ;; 1) @@ -417,8 +444,12 @@ start_services_and_show_instructions() { if [[ "$ENABLE_PROXY" == "true" ]]; then # Phase 1: Start core services (without proxy) + local core_services="traefik dashboard netbird-server" + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + core_services="$core_services crowdsec" + fi echo "Starting core services..." - $DOCKER_COMPOSE_COMMAND up -d traefik dashboard netbird-server + $DOCKER_COMPOSE_COMMAND up -d $core_services sleep 3 wait_management_proxy traefik @@ -438,7 +469,33 @@ start_services_and_show_instructions() { echo "Proxy token created successfully." - # Generate proxy.env with the token + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + echo "Registering CrowdSec bouncer..." + local cs_retries=0 + while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do + cs_retries=$((cs_retries + 1)) + if [[ $cs_retries -ge 30 ]]; then + echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr + echo "You can register a bouncer manually later with:" > /dev/stderr + echo " docker exec netbird-crowdsec cscli bouncers add netbird-proxy -o raw" > /dev/stderr + ENABLE_CROWDSEC="false" + break + fi + sleep 2 + done + + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + CROWDSEC_BOUNCER_KEY=$($DOCKER_COMPOSE_COMMAND exec -T crowdsec \ + cscli bouncers add netbird-proxy -o raw 2>/dev/null) + if [[ -z "$CROWDSEC_BOUNCER_KEY" ]]; then + echo "WARNING: Failed to create CrowdSec bouncer key. Skipping CrowdSec setup." > /dev/stderr + ENABLE_CROWDSEC="false" + else + echo "CrowdSec bouncer registered." + fi + fi + fi + render_proxy_env > proxy.env # Start proxy service @@ -525,11 +582,25 @@ render_docker_compose_traefik_builtin() { # Generate proxy service section and Traefik dynamic config if enabled local proxy_service="" local proxy_volumes="" + local crowdsec_service="" + local crowdsec_volumes="" local traefik_file_provider="" local traefik_dynamic_volume="" if [[ "$ENABLE_PROXY" == "true" ]]; then traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"' traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro" + + local proxy_depends=" + netbird-server: + condition: service_started" + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + proxy_depends=" + netbird-server: + condition: service_started + crowdsec: + condition: service_healthy" + fi + proxy_service=" # NetBird Proxy - exposes internal resources to the internet proxy: @@ -539,8 +610,7 @@ render_docker_compose_traefik_builtin() { - 51820:51820/udp restart: unless-stopped networks: [netbird] - depends_on: - - netbird-server + depends_on:${proxy_depends} env_file: - ./proxy.env volumes: @@ -563,6 +633,35 @@ render_docker_compose_traefik_builtin() { " proxy_volumes=" netbird_proxy_certs:" + + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + crowdsec_service=" + crowdsec: + image: crowdsecurity/crowdsec:v1.7.7 + container_name: netbird-crowdsec + restart: unless-stopped + networks: [netbird] + environment: + COLLECTIONS: crowdsecurity/linux + volumes: + - ./crowdsec:/etc/crowdsec + - crowdsec_db:/var/lib/crowdsec/data + healthcheck: + test: ["CMD", "cscli", "lapi", "status"] + interval: 10s + timeout: 5s + retries: 15 + labels: + - traefik.enable=false + logging: + driver: \"json-file\" + options: + max-size: \"500m\" + max-file: \"2\" +" + crowdsec_volumes=" + crowdsec_db:" + fi fi cat <" + echo " Get your enrollment key at: https://app.crowdsec.net" + echo "" + fi fi return 0 } diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 69d48f10a..54ac8ab18 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,6 +19,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/mock_server" resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -29,6 +31,13 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +func testCacheStore(t *testing.T) cachestore.StoreInterface { + t.Helper() + s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + return s +} + func TestInitializeServiceForCreate(t *testing.T) { ctx := context.Background() accountID := "test-account" @@ -422,10 +431,8 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer { t.Helper() - tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t)) srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) return srv } @@ -703,10 +710,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { }, } - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) @@ -1128,10 +1133,8 @@ func TestDeleteService_DeletesTargets(t *testing.T) { mockPerms := permissions.NewMockManager(ctrl) mockAcct := account.NewMockManager(ctrl) - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 88d37ca80..24dfb641b 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" @@ -26,6 +27,7 @@ import ( accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/store" @@ -58,6 +60,18 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics { }) } +// CacheStore returns a shared cache store backed by Redis or in-memory depending on the environment. +// All consumers should reuse this store to avoid creating multiple Redis connections. +func (s *BaseServer) CacheStore() cachestore.StoreInterface { + return Create(s, func() cachestore.StoreInterface { + cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultStoreMaxTimeout, nbcache.DefaultStoreCleanupInterval, nbcache.DefaultStoreMaxConn) + if err != nil { + log.Fatalf("failed to create shared cache store: %v", err) + } + return cs + }) +} + func (s *BaseServer) Store() store.Store { return Create(s, func() store.Store { store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) @@ -195,10 +209,7 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig { func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { return Create(s, func() *nbgrpc.OneTimeTokenStore { - tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100) - if err != nil { - log.Fatalf("failed to create proxy token store: %v", err) - } + tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), s.CacheStore()) log.Info("One-time token store initialized for proxy authentication") return tokenStore }) @@ -206,11 +217,7 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore { return Create(s, func() *nbgrpc.PKCEVerifierStore { - pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) - if err != nil { - log.Fatalf("failed to create PKCE verifier store: %v", err) - } - return pkceStore + return nbgrpc.NewPKCEVerifierStore(context.Background(), s.CacheStore()) }) } diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index c7eab3d19..9a8e45d33 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -41,7 +41,8 @@ func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValida context.Background(), s.PeersManager(), s.SettingsManager(), - s.EventStore()) + s.EventStore(), + s.CacheStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index f93c66b2f..ea94245d5 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -101,7 +101,7 @@ func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy, s.CacheStore()) if err != nil { log.Fatalf("failed to create account service: %v", err) } diff --git a/management/internals/shared/grpc/onetime_token.go b/management/internals/shared/grpc/onetime_token.go index 7999407db..acfd6eafb 100644 --- a/management/internals/shared/grpc/onetime_token.go +++ b/management/internals/shared/grpc/onetime_token.go @@ -14,8 +14,6 @@ import ( "github.com/eko/gocache/lib/v4/cache" "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" - - nbcache "github.com/netbirdio/netbird/management/server/cache" ) type tokenMetadata struct { @@ -32,17 +30,12 @@ type OneTimeTokenStore struct { ctx context.Context } -// NewOneTimeTokenStore creates a token store with automatic backend selection -func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) { - cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) - if err != nil { - return nil, fmt.Errorf("failed to create cache store: %w", err) - } - +// NewOneTimeTokenStore creates a token store using the provided shared cache store. +func NewOneTimeTokenStore(ctx context.Context, cacheStore store.StoreInterface) *OneTimeTokenStore { return &OneTimeTokenStore{ cache: cache.New[string](cacheStore), ctx: ctx, - }, nil + } } // GenerateToken creates a new cryptographically secure one-time token diff --git a/management/internals/shared/grpc/pkce_verifier.go b/management/internals/shared/grpc/pkce_verifier.go index 441e8b051..a1325256c 100644 --- a/management/internals/shared/grpc/pkce_verifier.go +++ b/management/internals/shared/grpc/pkce_verifier.go @@ -8,8 +8,6 @@ import ( "github.com/eko/gocache/lib/v4/cache" "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" - - nbcache "github.com/netbirdio/netbird/management/server/cache" ) // PKCEVerifierStore manages PKCE verifiers for OAuth flows. @@ -19,17 +17,12 @@ type PKCEVerifierStore struct { ctx context.Context } -// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection -func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) { - cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) - if err != nil { - return nil, fmt.Errorf("failed to create cache store: %w", err) - } - +// NewPKCEVerifierStore creates a PKCE verifier store using the provided shared cache store. +func NewPKCEVerifierStore(ctx context.Context, cacheStore store.StoreInterface) *PKCEVerifierStore { return &PKCEVerifierStore{ cache: cache.New[string](cacheStore), ctx: ctx, - }, nil + } } // Store saves a PKCE verifier associated with an OAuth state parameter. diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index d5aed3dee..de4e96d93 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -9,13 +9,22 @@ import ( "testing" "time" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/shared/management/proto" ) +func testCacheStore(t *testing.T) cachestore.StoreInterface { + t.Helper() + s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + return s +} + type testProxyController struct { mu sync.Mutex clusterProxies map[string]map[string]struct{} @@ -114,11 +123,8 @@ func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool { func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { ctx := context.Background() - tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -174,11 +180,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { ctx := context.Background() - tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -211,11 +214,8 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { ctx := context.Background() - tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -267,8 +267,7 @@ func generateState(s *ProxyServiceServer, redirectURL string) string { func TestOAuthState_NeverTheSame(t *testing.T) { ctx := context.Background() - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ @@ -296,8 +295,7 @@ func TestOAuthState_NeverTheSame(t *testing.T) { func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { ctx := context.Background() - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ @@ -307,7 +305,7 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { } // Old format had only 2 parts: base64(url)|hmac - err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute) + err := s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute) require.NoError(t, err) _, _, err = s.ValidateState("base64url|hmac") @@ -317,8 +315,7 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { func TestValidateState_RejectsInvalidHMAC(t *testing.T) { ctx := context.Background() - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ @@ -328,7 +325,7 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) { } // Store with tampered HMAC - err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute) + err := s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute) require.NoError(t, err) _, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac") @@ -337,8 +334,7 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) { } func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -410,8 +406,7 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { } func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -442,8 +437,7 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { // scenario for an existing service, verifying the correct update types // reach the correct clusters. func TestServiceModifyNotifications(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) { s := &ProxyServiceServer{ diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 2f77de86e..d1d7fc8b7 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -39,11 +39,8 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { usersManager := &testValidateSessionUsersManager{store: testStore} proxyManager := &testValidateSessionProxyManager{} - tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) proxyService.SetServiceManager(serviceManager) @@ -327,7 +324,7 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error { return nil } @@ -335,7 +332,7 @@ func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string return nil } -func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error { +func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { return nil } @@ -351,6 +348,18 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time return nil } +func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *testValidateSessionProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + +func (m *testValidateSessionProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + type testValidateSessionUsersManager struct { store store.Store } diff --git a/management/server/account.go b/management/server/account.go index 97ec77de7..a3cdbb302 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -181,7 +181,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] return modified, newUserAutoGroups, newGroupsToCreate, nil } -// BuildManager creates a new DefaultAccountManager with a provided Store +// BuildManager creates a new DefaultAccountManager with all dependencies. func BuildManager( ctx context.Context, config *nbconfig.Config, @@ -199,6 +199,7 @@ func BuildManager( settingsManager settings.Manager, permissionsManager permissions.Manager, disableDefaultPolicy bool, + sharedCacheStore cacheStore.StoreInterface, ) (*DefaultAccountManager, error) { start := time.Now() defer func() { @@ -247,16 +248,12 @@ func BuildManager( log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter) } - cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) - if err != nil { - return nil, fmt.Errorf("getting cache store: %s", err) - } - am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) - am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) + am.externalCacheManager = nbcache.NewUserDataCache(sharedCacheStore) + am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, sharedCacheStore) if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { go func() { - err := am.warmupIDPCache(ctx, cacheStore) + err := am.warmupIDPCache(ctx, sharedCacheStore) if err != nil { log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? diff --git a/management/server/account_test.go b/management/server/account_test.go index 2f0533281..4453d064e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3134,10 +3134,15 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU ctx := context.Background() + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } diff --git a/management/server/cache/store.go b/management/server/cache/store.go index 54b0242de..2ca8e8603 100644 --- a/management/server/cache/store.go +++ b/management/server/cache/store.go @@ -17,12 +17,24 @@ import ( // RedisStoreEnvVar is the environment variable that determines if a redis store should be used. // The value should follow redis URL format. https://github.com/redis/redis-specifications/blob/master/uri/redis.txt -const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" +const RedisStoreEnvVar = "NB_CACHE_REDIS_ADDRESS" + +// legacyIdPCacheRedisEnvVar is the previous environment variable used for IDP cache. +const legacyIdPCacheRedisEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" + +const ( + // DefaultStoreMaxTimeout is the default max timeout for the shared cache store. + DefaultStoreMaxTimeout = 7 * 24 * time.Hour + // DefaultStoreCleanupInterval is the default cleanup interval for the shared cache store. + DefaultStoreCleanupInterval = 30 * time.Minute + // DefaultStoreMaxConn is the default max connections for the shared cache store. + DefaultStoreMaxConn = 1000 +) // NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar // to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store. func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) { - redisAddr := os.Getenv(RedisStoreEnvVar) + redisAddr := GetAddrFromEnv() if redisAddr != "" { return getRedisStore(ctx, redisAddr, maxConn) } @@ -30,6 +42,15 @@ func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, ma return gocache_store.NewGoCache(goc), nil } +// GetAddrFromEnv returns the redis address from the environment variable RedisStoreEnvVar or its legacy counterpart. +func GetAddrFromEnv() string { + addr := os.Getenv(RedisStoreEnvVar) + if addr == "" { + addr = os.Getenv(legacyIdPCacheRedisEnvVar) + } + return addr +} + func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) { options, err := redis.ParseURL(redisEnvAddr) if err != nil { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index bd0755d0d..0e37a3b22 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/permissions" @@ -225,11 +226,17 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } func createDNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index c311a29fe..ce9efb78d 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -105,6 +105,12 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { router.NetworkID = networkID router.AccountID = accountID router.Enabled = true + + if err := router.Validate(); err != nil { + util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) + return + } + router, err = h.routersManager.CreateRouter(r.Context(), userID, router) if err != nil { util.WriteError(r.Context(), err, w) @@ -157,6 +163,11 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { router.ID = mux.Vars(r)["routerId"] router.AccountID = accountID + if err := router.Validate(); err != nil { + util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) + return + } + router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 922bf4352..c99acab63 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -22,6 +22,7 @@ import ( nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -191,11 +192,11 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { oidcServer := newFakeOIDCServer() - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) usersManager := users.NewManager(testStore) diff --git a/management/server/http/testing/integration/networks_handler_integration_test.go b/management/server/http/testing/integration/networks_handler_integration_test.go index 4cb6b268b..54f204a8f 100644 --- a/management/server/http/testing/integration/networks_handler_integration_test.go +++ b/management/server/http/testing/integration/networks_handler_integration_test.go @@ -1170,13 +1170,17 @@ func Test_NetworkRouters_Create(t *testing.T) { Metric: 100, Enabled: true, }, - expectedStatus: http.StatusOK, - verifyResponse: func(t *testing.T, router *api.NetworkRouter) { - t.Helper() - assert.NotEmpty(t, router.Id) - assert.Equal(t, peerID, *router.Peer) - assert.Equal(t, 1, len(*router.PeerGroups)) + expectedStatus: http.StatusBadRequest, + }, + { + name: "Create router without peer and peer_groups", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Masquerade: true, + Metric: 100, + Enabled: true, }, + expectedStatus: http.StatusBadRequest, }, { name: "Create router in non-existing network", @@ -1341,13 +1345,18 @@ func Test_NetworkRouters_Update(t *testing.T) { Metric: 100, Enabled: true, }, - expectedStatus: http.StatusOK, - verifyResponse: func(t *testing.T, router *api.NetworkRouter) { - t.Helper() - assert.Equal(t, "testRouterId", router.Id) - assert.Equal(t, peerID, *router.Peer) - assert.Equal(t, 1, len(*router.PeerGroups)) + expectedStatus: http.StatusBadRequest, + }, + { + name: "Update router without peer and peer_groups", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Masquerade: true, + Metric: 100, + Enabled: true, }, + expectedStatus: http.StatusBadRequest, }, } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index d9d85a0a2..0203d6177 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -35,6 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" serverauth "github.com/netbirdio/netbird/management/server/auth" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" http2 "github.com/netbirdio/netbird/management/server/http" @@ -87,22 +88,22 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee jobManager := job.NewJobManager(nil, store, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("Failed to create cache store: %v", err) + } + requestBuffer := server.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}) - am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore) if err != nil { t.Fatalf("Failed to create manager: %v", err) } accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) - proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create proxy token store: %v", err) - } - pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create PKCE verifier store: %v", err) - } + proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) noopMeter := noop.NewMeterProvider().Meter("") proxyMgr, err := proxymanager.NewManager(store, noopMeter) if err != nil { @@ -216,22 +217,22 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin jobManager := job.NewJobManager(nil, store, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("Failed to create cache store: %v", err) + } + requestBuffer := server.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}) - am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore) if err != nil { t.Fatalf("Failed to create manager: %v", err) } accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) - proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create proxy token store: %v", err) - } - pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create PKCE verifier store: %v", err) - } + proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) noopMeter := noop.NewMeterProvider().Meter("") proxyMgr, err := proxymanager.NewManager(store, noopMeter) if err != nil { diff --git a/management/server/identity_provider_test.go b/management/server/identity_provider_test.go index 9fce6b9c0..d51254c55 100644 --- a/management/server/identity_provider_test.go +++ b/management/server/identity_provider_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "path/filepath" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -19,6 +20,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -83,10 +85,15 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update permissionsManager := permissions.NewManager(testStore) peersManager := peers.NewManager(testStore, permissionsManager) + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, testStore) networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{}) - manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 48e4f3000..dadbfd83e 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -66,14 +66,14 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient } // Create a new Admin SDK Directory service client - adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey) + credentialsOption, err := getGoogleCredentialsOption(ctx, config.ServiceAccountKey) if err != nil { return nil, err } service, err := admin.NewService(context.Background(), option.WithScopes(admin.AdminDirectoryUserReadonlyScope), - option.WithCredentials(adminCredentials), + credentialsOption, ) if err != nil { return nil, err @@ -218,39 +218,32 @@ func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) e return nil } -// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey. -// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. -// If that fails, it falls back to using the default Google credentials path. -// It returns the retrieved credentials or an error if unsuccessful. -func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) { +// getGoogleCredentialsOption returns the google.golang.org/api option carrying +// Google credentials derived from the provided serviceAccountKey. +// It decodes the base64-encoded serviceAccountKey and uses it as the credentials JSON. +// If the key is empty, it falls back to the default Google credentials path. +func getGoogleCredentialsOption(ctx context.Context, serviceAccountKey string) (option.ClientOption, error) { log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key") decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey) if err != nil { return nil, fmt.Errorf("failed to decode service account key: %w", err) } - creds, err := google.CredentialsFromJSON( - context.Background(), - decodeKey, - admin.AdminDirectoryUserReadonlyScope, - ) - if err == nil { - // No need to fallback to the default Google credentials path - return creds, nil + if len(decodeKey) > 0 { + return option.WithAuthCredentialsJSON(option.ServiceAccount, decodeKey), nil } - log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) - log.WithContext(ctx).Debug("falling back to default google credentials location") + log.WithContext(ctx).Debug("no service account key provided, falling back to default google credentials location") - creds, err = google.FindDefaultCredentials( - context.Background(), + creds, err := google.FindDefaultCredentials( + ctx, admin.AdminDirectoryUserReadonlyScope, ) if err != nil { return nil, err } - return creds, nil + return option.WithCredentials(creds), nil } // parseGoogleWorkspaceUser parse google user to UserData. diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 090c99877..4e6eb0a33 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -369,9 +370,15 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config requestBuffer := NewAccountRequestBuffer(ctx, store) ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)) + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + cleanup() + return nil, nil, "", cleanup, err + } + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config) accountManager, err := BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", - eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { cleanup() diff --git a/management/server/management_test.go b/management/server/management_test.go index de02855bf..3ac28cd4a 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -28,6 +28,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -207,6 +208,12 @@ func startServer( jobManager := job.NewJobManager(nil, str, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("failed creating cache store: %v", err) + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config) @@ -227,7 +234,8 @@ func startServer( port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, - false) + false, + cacheStore) if err != nil { t.Fatalf("failed creating an account manager: %v", err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 90b4b9687..d10d4464f 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -17,6 +17,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -794,11 +795,17 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } func createNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index e90c61a97..1293a9934 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -21,11 +21,7 @@ type NetworkRouter struct { } func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int, enabled bool) (*NetworkRouter, error) { - if peer != "" && len(peerGroups) > 0 { - return nil, errors.New("peer and peerGroups cannot be set at the same time") - } - - return &NetworkRouter{ + r := &NetworkRouter{ ID: xid.New().String(), AccountID: accountID, NetworkID: networkID, @@ -34,7 +30,25 @@ func NewNetworkRouter(accountID string, networkID string, peer string, peerGroup Masquerade: masquerade, Metric: metric, Enabled: enabled, - }, nil + } + + if err := r.Validate(); err != nil { + return nil, err + } + + return r, nil +} + +func (n *NetworkRouter) Validate() error { + if n.Peer != "" && len(n.PeerGroups) > 0 { + return errors.New("peer and peer_groups cannot be set at the same time") + } + + if n.Peer == "" && len(n.PeerGroups) == 0 { + return errors.New("either peer or peer_groups must be provided") + } + + return nil } func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter { diff --git a/management/server/networks/routers/types/router_test.go b/management/server/networks/routers/types/router_test.go index 5801e3bfa..a2f2fe6e3 100644 --- a/management/server/networks/routers/types/router_test.go +++ b/management/server/networks/routers/types/router_test.go @@ -38,7 +38,7 @@ func TestNewNetworkRouter(t *testing.T) { expectedError: false, }, { - name: "Valid with no peer or peerGroups", + name: "Invalid with no peer or peerGroups", networkID: "network-3", accountID: "account-3", peer: "", @@ -46,7 +46,18 @@ func TestNewNetworkRouter(t *testing.T) { masquerade: true, metric: 300, enabled: true, - expectedError: false, + expectedError: true, + }, + { + name: "Invalid with empty peerGroups slice", + networkID: "network-5", + accountID: "account-5", + peer: "", + peerGroups: []string{}, + masquerade: true, + metric: 500, + enabled: true, + expectedError: true, }, // Invalid cases diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 51c16d730..6f8d924fd 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -32,6 +32,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -1294,11 +1295,15 @@ func Test_RegisterPeerByUser(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1380,11 +1385,15 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1534,11 +1543,15 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1615,11 +1628,15 @@ func Test_LoginPeer(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" diff --git a/management/server/route_test.go b/management/server/route_test.go index d4882eff8..91b2cf982 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -20,6 +20,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -1293,11 +1294,17 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 17510f37e..4b1ecf922 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -22,6 +22,7 @@ import ( nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -113,11 +114,11 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { } // Create real token store - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) // Create real users manager usersManager := users.NewManager(testStore) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index f5edb6b95..d9a1a7d65 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -31,6 +31,7 @@ import ( "github.com/netbirdio/netbird/management/internals/server/config" mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" @@ -95,9 +96,16 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { settingsManagerMock := settings.NewMockManager(ctrl) jobManager := job.NewJobManager(nil, store, peersManger) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore) + ctx := context.Background() - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatal(err) + } + + ia, _ := integrations.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore) + + metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) settingsMockManager := settings.NewMockManager(ctrl) @@ -116,11 +124,10 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). AnyTimes() - ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManger), config) - accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { t.Fatal(err) }