diff --git a/client/internal/dns/file_linux.go b/client/internal/dns/file_linux.go index c62da7016..b427a30e1 100644 --- a/client/internal/dns/file_linux.go +++ b/client/internal/dns/file_linux.go @@ -8,6 +8,7 @@ import ( "net/netip" "os" "strings" + "time" log "github.com/sirupsen/logrus" ) @@ -23,10 +24,16 @@ const ( fileMaxNumberOfSearchDomains = 6 ) +const ( + dnsFailoverTimeout = 4 * time.Second + dnsFailoverAttempts = 1 +) + type fileConfigurator struct { repair *repair - originalPerms os.FileMode + originalPerms os.FileMode + nbNameserverIP string } func newFileConfigurator() (hostManager, error) { @@ -64,7 +71,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { } nbSearchDomains := searchDomains(config) - nbNameserverIP := config.ServerIP + f.nbNameserverIP = config.ServerIP resolvConf, err := parseBackupResolvConf() if err != nil { @@ -73,11 +80,11 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { f.repair.stopWatchFileChanges() - err = f.updateConfig(nbSearchDomains, nbNameserverIP, resolvConf) + err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf) if err != nil { return err } - f.repair.watchFileChanges(nbSearchDomains, nbNameserverIP) + f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP) return nil } @@ -85,10 +92,11 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) nameServers := generateNsList(nbNameserverIP, cfg) + options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts) buf := prepareResolvConfContent( searchDomainList, nameServers, - cfg.others) + options) log.Debugf("creating managed file %s", defaultResolvConfPath) err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms) @@ -131,7 +139,12 @@ func (f *fileConfigurator) backup() error { } func (f *fileConfigurator) restore() error { - err := 
copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath) + err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP) + if err != nil { + log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err) + } + + err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath) if err != nil { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) } @@ -157,7 +170,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0]) // not a valid first nameserver -> restore if err != nil { - log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[1], err) + log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err) return restoreResolvConfFile() } diff --git a/client/internal/dns/file_parser_linux.go b/client/internal/dns/file_parser_linux.go index 95e1ddb54..02f6d03a5 100644 --- a/client/internal/dns/file_parser_linux.go +++ b/client/internal/dns/file_parser_linux.go @@ -5,6 +5,7 @@ package dns import ( "fmt" "os" + "regexp" "strings" log "github.com/sirupsen/logrus" @@ -14,6 +15,9 @@ const ( defaultResolvConfPath = "/etc/resolv.conf" ) +var timeoutRegex = regexp.MustCompile(`timeout:\d+`) +var attemptsRegex = regexp.MustCompile(`attempts:\d+`) + type resolvConf struct { nameServers []string searchDomains []string @@ -103,3 +107,62 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { } return rconf, nil } + +// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist, +// otherwise it adds a new option with timeout and attempts. 
// prepareOptionsWithTimeout returns a copy of input in which the first
// "options" line carries the given timeout and attempts values and no longer
// contains the "rotate" option.
//
// Existing "timeout:" and "attempts:" entries are overwritten in place. Missing
// ones are inserted directly after the "options" keyword, attempts first and
// then timeout. Only the first "options" line is rewritten; later ones are
// returned untouched. If input has no "options" line at all, a new
// "options timeout:<timeout> attempts:<attempts>" line is appended.
//
// The input slice is never modified.
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
	configs := make([]string, len(input))
	copy(configs, input)

	timeoutOpt := fmt.Sprintf("timeout:%d", timeout)
	attemptsOpt := fmt.Sprintf("attempts:%d", attempts)

	for i, config := range configs {
		fields := strings.Fields(config)
		if len(fields) == 0 || fields[0] != "options" {
			continue
		}

		// Operate on whole tokens instead of substrings: a line that is left
		// as a bare "options" keyword (e.g. it only contained "rotate") still
		// receives both values, and unrelated text is never cut apart.
		kept := make([]string, 0, len(fields))
		hasTimeout, hasAttempts := false, false
		for _, opt := range fields[1:] {
			switch {
			case opt == "rotate":
				// Dropped on purpose — presumably so the resolver always asks
				// the first listed nameserver first; confirm against resolv.conf(5).
			case strings.HasPrefix(opt, "timeout:"):
				kept = append(kept, timeoutOpt)
				hasTimeout = true
			case strings.HasPrefix(opt, "attempts:"):
				kept = append(kept, attemptsOpt)
				hasAttempts = true
			default:
				kept = append(kept, opt)
			}
		}

		line := make([]string, 0, len(kept)+3)
		line = append(line, "options")
		if !hasAttempts {
			line = append(line, attemptsOpt)
		}
		if !hasTimeout {
			line = append(line, timeoutOpt)
		}
		line = append(line, kept...)

		configs[i] = strings.Join(line, " ")
		return configs
	}

	return append(configs, "options "+timeoutOpt+" "+attemptsOpt)
}

// removeFirstNbNameserver rewrites filename without its first nameserver entry
// when that entry equals nameserverIP and at least one other nameserver entry
// exists. In every other case the file is left untouched.
//
// It returns an error if the file cannot be read, stat'ed or written.
func removeFirstNbNameserver(filename, nameserverIP string) error {
	content, err := os.ReadFile(filename)
	if err != nil {
		return fmt.Errorf("read %s: %w", filename, err)
	}

	newContent, changed := dropLeadingNameserver(string(content), nameserverIP)
	if !changed {
		return nil
	}

	stat, err := os.Stat(filename)
	if err != nil {
		return fmt.Errorf("stat %s: %w", filename, err)
	}
	if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
		return fmt.Errorf("write %s: %w", filename, err)
	}

	return nil
}

// dropLeadingNameserver removes the first nameserver line from content when its
// address equals nameserverIP and a second nameserver line follows. It reports
// whether anything was removed; all other bytes are preserved exactly.
func dropLeadingNameserver(content, nameserverIP string) (string, bool) {
	// SplitAfter keeps each line's terminator, so joining the remaining
	// lines reproduces the rest of the file byte for byte.
	lines := strings.SplitAfter(content, "\n")

	first := -1
	for i, line := range lines {
		fields := strings.Fields(line)
		// Comment lines start with a different first token and are skipped.
		if len(fields) < 2 || fields[0] != "nameserver" {
			continue
		}

		if first == -1 {
			if fields[1] != nameserverIP {
				return content, false
			}
			first = i
			continue
		}

		// A second nameserver exists, so removing the first one cannot
		// leave the file without any resolver.
		remaining := make([]string, 0, len(lines)-1)
		remaining = append(remaining, lines[:first]...)
		remaining = append(remaining, lines[first+1:]...)
		return strings.Join(remaining, ""), true
	}

	return content, false
}
b/client/internal/dns/file_parser_linux_test.go index 180ad2f9d..4263d4063 100644 --- a/client/internal/dns/file_parser_linux_test.go +++ b/client/internal/dns/file_parser_linux_test.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/stretchr/testify/assert" ) func Test_parseResolvConf(t *testing.T) { @@ -172,3 +174,131 @@ nameserver 192.168.0.1 t.Errorf("unexpected resolv.conf content: %v", cfg) } } + +func TestPrepareOptionsWithTimeout(t *testing.T) { + tests := []struct { + name string + others []string + timeout int + attempts int + expected []string + }{ + { + name: "Append new options with timeout and attempts", + others: []string{"some config"}, + timeout: 2, + attempts: 2, + expected: []string{"some config", "options timeout:2 attempts:2"}, + }, + { + name: "Modify existing options to exclude rotate and include timeout and attempts", + others: []string{"some config", "options rotate someother"}, + timeout: 3, + attempts: 2, + expected: []string{"some config", "options attempts:2 timeout:3 someother"}, + }, + { + name: "Existing options with timeout and attempts are updated", + others: []string{"some config", "options timeout:4 attempts:3"}, + timeout: 5, + attempts: 4, + expected: []string{"some config", "options timeout:5 attempts:4"}, + }, + { + name: "Modify existing options, add missing attempts before timeout", + others: []string{"some config", "options timeout:4"}, + timeout: 4, + attempts: 3, + expected: []string{"some config", "options attempts:3 timeout:4"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestRemoveFirstNbNameserver(t *testing.T) { + testCases := []struct { + name string + content string + ipToRemove string + expected string + }{ + { + name: "Unrelated nameservers with comments and options", + content: `# This is a comment +options rotate 
+nameserver 1.1.1.1 +# Another comment +nameserver 8.8.4.4 +search example.com`, + ipToRemove: "9.9.9.9", + expected: `# This is a comment +options rotate +nameserver 1.1.1.1 +# Another comment +nameserver 8.8.4.4 +search example.com`, + }, + { + name: "First nameserver matches", + content: `search example.com +nameserver 9.9.9.9 +# oof, a comment +nameserver 8.8.4.4 +options attempts:5`, + ipToRemove: "9.9.9.9", + expected: `search example.com +# oof, a comment +nameserver 8.8.4.4 +options attempts:5`, + }, + { + name: "Target IP not the first nameserver", + // nolint:dupword + content: `# Comment about the first nameserver +nameserver 8.8.4.4 +# Comment before our target +nameserver 9.9.9.9 +options timeout:2`, + ipToRemove: "9.9.9.9", + // nolint:dupword + expected: `# Comment about the first nameserver +nameserver 8.8.4.4 +# Comment before our target +nameserver 9.9.9.9 +options timeout:2`, + }, + { + name: "Only nameserver matches", + content: `options debug +nameserver 9.9.9.9 +search localdomain`, + ipToRemove: "9.9.9.9", + expected: `options debug +nameserver 9.9.9.9 +search localdomain`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tempDir := t.TempDir() + tempFile := filepath.Join(tempDir, "resolv.conf") + err := os.WriteFile(tempFile, []byte(tc.content), 0644) + assert.NoError(t, err) + + err = removeFirstNbNameserver(tempFile, tc.ipToRemove) + assert.NoError(t, err) + + content, err := os.ReadFile(tempFile) + assert.NoError(t, err) + + assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.") + }) + } +} diff --git a/client/internal/dns/host_linux.go b/client/internal/dns/host_linux.go index 37d8f704a..cb246bcfe 100644 --- a/client/internal/dns/host_linux.go +++ b/client/internal/dns/host_linux.go @@ -65,7 +65,7 @@ func newHostManager(wgInterface string) (hostManager, error) { return nil, err } - log.Debugf("discovered mode is: %s", osManager) + log.Infof("System DNS 
manager discovered: %s", osManager) return newHostManagerFromType(wgInterface, osManager) } diff --git a/client/internal/dns/resolvconf_linux.go b/client/internal/dns/resolvconf_linux.go index b8f753e28..72db5faf1 100644 --- a/client/internal/dns/resolvconf_linux.go +++ b/client/internal/dns/resolvconf_linux.go @@ -53,10 +53,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { searchDomainList := searchDomains(config) searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains) + options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts) + buf := prepareResolvConfContent( searchDomainList, append([]string{config.ServerIP}, r.originalNameServers...), - r.othersConfigs) + options) // create a backup for unclean shutdown detection before the resolv.conf is changed if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil { diff --git a/management/client/grpc.go b/management/client/grpc.go index a426413a0..0234f866c 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -26,6 +26,8 @@ import ( "github.com/netbirdio/netbird/management/proto" ) +const ConnectTimeout = 10 * time.Second + // ConnStateNotifier is a wrapper interface of the status recorders type ConnStateNotifier interface { MarkManagementDisconnected(error) @@ -49,7 +51,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) } - mgmCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) defer cancel() conn, err := grpc.DialContext( mgmCtx, @@ -318,7 +320,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro log.Errorf("failed to encrypt message: %s", err) return nil, err } - mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) + mgmCtx, 
cancel := context.WithTimeout(c.ctx, ConnectTimeout) defer cancel() resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ WgPubKey: c.key.PublicKey().String(), diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 07276aef1..7531608c3 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -21,11 +21,10 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/signal/proto" ) -const defaultSendTimeout = 5 * time.Second - // ConnStateNotifier is a wrapper interface of the status recorder type ConnStateNotifier interface { MarkSignalDisconnected(error) @@ -71,7 +70,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) } - sigCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout) defer cancel() conn, err := grpc.DialContext( sigCtx, @@ -353,7 +352,7 @@ func (c *GrpcClient) Send(msg *proto.Message) error { return err } - attemptTimeout := defaultSendTimeout + attemptTimeout := client.ConnectTimeout for attempt := 0; attempt < 4; attempt++ { if attempt > 1 {