diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..71badf0d4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -13,6 +13,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - - if err := stateManager.UpdateState(&ShutdownState{}); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } - var ( searchDomains []string matchDomains []string ) - err = s.recordSystemDNSSettings(true) - if err != nil { + if err := s.recordSystemDNSSettings(true); err != nil { log.Errorf("unable to update record of System's DNS config: %s", err.Error()) } if config.RouteAll { searchDomains = append(searchDomains, "\"\"") - err = s.addLocalDNS() - if err != nil { - log.Infof("failed to enable split DNS") + if err := s.addLocalDNS(); err != nil { + log.Warnf("failed to add local DNS: %v", err) } + s.updateState(stateManager) } for _, dConf := range config.Domains { @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + var err error if len(matchDomains) != 0 { err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add match domains: %w", err) } + s.updateState(stateManager) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add search domains: %w", err) } + s.updateState(stateManager) if err := s.flushDNSCache(); err != nil { log.Errorf("failed to flush DNS cache: %v", err) @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (s *systemConfigurator) string() string { return "scutil" } @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get system DNS configuration") return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { - err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) - if err != nil { - return fmt.Errorf("couldn't add local network DNS conf: %w", err) - } - } else { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") + return nil + } + + if err := s.addSearchDomains( + localKey, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + ); err != nil { + return fmt.Errorf("add search domains: %w", err) } return nil diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..c4efd17b0 --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,111 @@ +//go:build !ios + +package dns + +import ( + "context" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + defer func() { + require.NoError(t, sm.Stop(context.Background())) + }() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + config := HostDNSConfig{ + ServerIP: netip.MustParseAddr("100.64.0.1"), + ServerPort: 53, + RouteAll: true, + Domains: []DomainConfig{ + {Domain: "example.com", MatchOnly: true}, + }, + } + + err := configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + require.NoError(t, sm.PersistState(context.Background())) + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + defer func() { + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + }() + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + if exists { + t.Logf("Key %s exists before cleanup", key) + } + } + + sm2 := statemanager.New(stateFile) + sm2.RegisterState(&ShutdownState{}) + err = sm2.LoadState(&ShutdownState{}) + require.NoError(t, err) + + state := sm2.GetState(&ShutdownState{}) + if state == nil { + t.Skip("State not saved, skipping cleanup test") + } + + shutdownState, ok := state.(*ShutdownState) + require.True(t, ok) + + err = shutdownState.Cleanup() + require.NoError(t, err) + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + assert.False(t, exists, "Key %s should NOT exist after cleanup", key) + } +} + +func checkDNSKeyExists(key string) (bool, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") + output, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(output), "No such key") { + return false, nil + } + return false, err + } + return !strings.Contains(string(output), "No such key"), nil +} + +func removeTestDNSKey(key string) error { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") + _, err := cmd.CombinedOutput() + return err +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 74111d335..01b7edc48 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -179,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) var searchDomains, matchDomains []string for _, dConf := range config.Domains { @@ -212,13 +206,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager r.nrptEntryCount = 0 } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) if err := r.updateSearchDomains(searchDomains); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -229,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{ + Guid: r.guid, + GPO: r.gpo, + NRPTEntryCount: r.nrptEntryCount, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index 9bbdd2b56..f51b5cf8d 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -7,6 +7,7 @@ import ( ) type ShutdownState struct { + CreatedKeys []string } func (s *ShutdownState) Name() string { @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create host manager: %w", err) } + for _, key := range s.CreatedKeys { + manager.createdKeys[key] = struct{}{} + } + if err := manager.restoreUncleanShutdownDNS(); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index 4563e5ef6..7089178fb 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -48,7 +48,7 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout return r.cleanupRefCounter(stateManager) } -// FlushMarkedRoutes removes all routes marked with the configured RTF_PROTO flag. +// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. func (r *SysOps) FlushMarkedRoutes() error { rib, err := retryFetchRIB() if err != nil { diff --git a/go.mod b/go.mod index 79dd92e6b..06dc921f1 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index f0065e081..ce68ed99e 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 h1:aXHS63QWf0Z5fDN19Swl6npdJjGMyXthAvvgW7rbKJQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b24e853b4..2bc49d3e5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,21 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console" + ] # Relay relay: diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index daec4ef6f..209a20065 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - return integrations.InitPermissionsManager(s.Store()) + manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) + + s.AfterInit(func(s *BaseServer) { + manager.SetAccountManager(s.AccountManager()) + }) + + return manager }) } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 741f03f18..bdf56db6e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,9 +7,10 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..e6bdd2025 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -22,6 +23,7 @@ type Manager interface { ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) + SetAccountManager(accountManager account.Manager) } type managerImpl struct { @@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR return permissions, nil } + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + // no-op +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index fa115d628..ec9f263f9 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + account "github.com/netbirdio/netbird/management/server/account" modules "github.com/netbirdio/netbird/management/server/permissions/modules" operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" @@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) } +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() diff --git a/management/server/types/account.go b/management/server/types/account.go index f830023c7..50bdc6ab3 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -301,7 +301,7 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = append(zones, nbdns.CustomZone{ Domain: peersCustomZone.Domain, Records: records, @@ -1682,7 +1682,7 @@ func peerSupportsPortRanges(peerVer string) bool { } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. -func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) peerIPs := make(map[string]struct{}) @@ -1693,6 +1693,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p peerIPs[peerToConnect.IP.String()] = struct{}{} } + for _, expiredPeer := range expiredPeers { + peerIPs[expiredPeer.IP.String()] = struct{}{} + } + for _, record := range customZone.Records { if _, exists := peerIPs[record.RData]; exists { filteredRecords = append(filteredRecords, record) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..32538933a 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { peer *nbpeer.Peer customZone nbdns.CustomZone peersToConnect []*nbpeer.Peer + expiredPeers []*nbpeer.Peer expectedRecords []nbdns.SimpleRecord }{ { @@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{}, + expiredPeers: []*nbpeer.Peer{}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, @@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { } return peers }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, }, + { + name: "expired peers are included in DNS entries", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + }, + expiredPeers: []*nbpeer.Peer{ + {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers) assert.Equal(t, len(tt.expectedRecords), len(result)) assert.ElementsMatch(t, tt.expectedRecords, result) }) diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 96873dee7..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -94,7 +94,7 @@ var ( startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -132,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -140,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -202,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {