diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 65ae0aa26..1889b58e7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,8 +10,10 @@ on: env: - SIGN_PIPE_VER: "v0.0.11" + SIGN_PIPE_VER: "v0.0.12" GORELEASER_VER: "v1.14.1" + PRODUCT_NAME: "NetBird" + COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} @@ -23,6 +25,13 @@ jobs: env: flags: "" steps: + - name: Parse semver string + id: semver_parser + uses: booxmedialtd/ws-action-parse-semver@v1 + with: + input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }} + version_extractor_regex: '\/v(.*)$' + - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - @@ -68,18 +77,18 @@ jobs: - name: Install OS build dependencies run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu - - name: Install rsrc - run: go install github.com/akavel/rsrc@v0.10.2 - - name: Generate windows rsrc amd64 - run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso - - name: Generate windows rsrc arm64 - run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso - - name: Generate windows rsrc arm - run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso - - name: Generate windows rsrc 386 - run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso - - - name: Run GoReleaser + - name: Install goversioninfo + run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e + - name: Generate windows syso 386 + run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_386.syso + - name: Generate windows syso arm + run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm.syso + - name: Generate windows syso arm64 + run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm64.syso + - name: Generate windows syso amd64 + run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_amd64.syso + + - name: Run GoReleaser uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} @@ -121,6 +130,13 @@ jobs: release_ui: runs-on: ubuntu-latest steps: + - name: Parse semver string + id: semver_parser + uses: booxmedialtd/ws-action-parse-semver@v1 + with: + input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }} + version_extractor_regex: '\/v(.*)$' + - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - name: Checkout @@ -151,10 +167,11 @@ jobs: - name: Install dependencies run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 - - name: Install rsrc - run: go install github.com/akavel/rsrc@v0.10.2 - - name: Generate windows rsrc - run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso + - name: Install goversioninfo + run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e + - name: Generate windows syso amd64 + run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/ui/resources_windows_amd64.syso + - name: Run GoReleaser uses: goreleaser/goreleaser-action@v4 with: diff --git a/client/cmd/down.go b/client/cmd/down.go index 1837b13da..4d9f1eba4 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -26,7 +26,7 @@ var downCmd = &cobra.Command{ return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) defer cancel() conn, err := DialClientGRPCServer(ctx, daemonAddr) diff --git a/client/cmd/root.go b/client/cmd/root.go index f0b5d2bdf..1e5c56366 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -121,7 +121,7 @@ func init() { rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") - rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout") + rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 427a73825..75792e9c0 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -337,7 +337,6 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { return rule.drop, true } - return rule.drop, true case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: return rule.drop, true } diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index 3070763a6..e55a07055 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -15,6 +15,12 @@ type hostManager interface { restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error } +type SystemDNSSettings struct { + Domains []string + ServerIP string + ServerPort int +} + type HostDNSConfig struct { Domains []DomainConfig `json:"domains"` RouteAll bool `json:"routeAll"` diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 5ae84fb91..5dee305c2 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -7,6 +7,7 @@ import ( "bytes" "fmt" "io" + "net" "net/netip" "os/exec" "strconv" @@ -18,7 +19,7 @@ import ( const ( netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" globalIPv4State = "State:/Network/Global/IPv4" - primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS" + primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS" keySupplementalMatchDomains = "SupplementalMatchDomains" keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch" keyServerAddresses = "ServerAddresses" @@ -28,12 +29,12 @@ const ( scutilPath = "/usr/sbin/scutil" searchSuffix = "Search" matchSuffix = "Match" + localSuffix = "Local" ) type systemConfigurator struct { - // primaryServiceID primary interface in the system. AKA the interface with the default route - primaryServiceID string - createdKeys map[string]struct{} + createdKeys map[string]struct{} + systemDNSSettings SystemDNSSettings } func newHostManager() (hostManager, error) { @@ -49,20 +50,6 @@ func (s *systemConfigurator) supportCustomPort() bool { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { var err error - if config.RouteAll { - err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort) - if err != nil { - return fmt.Errorf("add dns setup for all: %w", err) - } - } else if s.primaryServiceID != "" { - err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) - if err != nil { - return fmt.Errorf("remote key from system config: %w", err) - } - s.primaryServiceID = "" - log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort) - } - // create a file for unclean shutdown detection if err := createUncleanShutdownIndicator(); err != nil { log.Errorf("failed to create unclean shutdown file: %s", err) @@ -73,6 +60,19 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { matchDomains []string ) + err = s.recordSystemDNSSettings(true) + if err != nil { + log.Errorf("unable to update record of System's DNS config: %s", err.Error()) + } + + if config.RouteAll { + searchDomains = append(searchDomains, "\"\"") + err = s.addLocalDNS() + if err != nil { + log.Infof("failed to enable split DNS") + } + } + for _, dConf := range config.Domains { if dConf.Disabled { continue @@ -110,23 +110,17 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { } func (s *systemConfigurator) restoreHostDNS() error { - lines := "" - for key := range s.createdKeys { - lines += buildRemoveKeyOperation(key) + keys := s.getRemovableKeysWithDefaults() + for _, key := range keys { keyType := "search" if strings.Contains(key, matchSuffix) { keyType = "match" } log.Infof("removing %s domains from system", keyType) - } - if s.primaryServiceID != "" { - lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) - log.Infof("restoring DNS resolver configuration for system") - } - _, err := runSystemConfigCommand(wrapCommand(lines)) - if err != nil { - log.Errorf("got an error while cleaning the system configuration: %s", err) - return fmt.Errorf("clean system: %w", err) + err := s.removeKeyFromSystemConfig(key) + if err != nil { + log.Errorf("failed to remove %s domains from system: %s", keyType, err) + } } if err := removeUncleanShutdownIndicator(); err != nil { @@ -136,6 +130,19 @@ func (s *systemConfigurator) restoreHostDNS() error { return nil } +func (s *systemConfigurator) getRemovableKeysWithDefaults() []string { + if len(s.createdKeys) == 0 { + // return defaults for startup calls + return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)} + } + + keys := make([]string, 0, len(s.createdKeys)) + for key := range s.createdKeys { + keys = append(keys, key) + } + return keys +} + func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { line := buildRemoveKeyOperation(key) _, err := runSystemConfigCommand(wrapCommand(line)) @@ -148,6 +155,97 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { return nil } +func (s *systemConfigurator) addLocalDNS() error { + if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 { + err := s.recordSystemDNSSettings(true) + log.Errorf("Unable to get system DNS configuration") + return err + } + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 { + err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) + if err != nil { + return fmt.Errorf("couldn't add local network DNS conf: %w", err) + } + } else { + log.Info("Not enabling local DNS server") + } + + return nil +} + +func (s *systemConfigurator) recordSystemDNSSettings(force bool) error { + if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force { + return nil + } + + systemDNSSettings, err := s.getSystemDNSSettings() + if err != nil { + return fmt.Errorf("couldn't get current DNS config: %w", err) + } + s.systemDNSSettings = systemDNSSettings + + return nil +} + +func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { + primaryServiceKey, _, err := s.getPrimaryService() + if err != nil || primaryServiceKey == "" { + return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err) + } + dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey) + line := buildCommandLine("show", dnsServiceKey, "") + stdinCommands := wrapCommand(line) + + b, err := runSystemConfigCommand(stdinCommands) + if err != nil { + return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err) + } + + var dnsSettings SystemDNSSettings + inSearchDomainsArray := false + inServerAddressesArray := false + + scanner := bufio.NewScanner(bytes.NewReader(b)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + switch { + case strings.HasPrefix(line, "DomainName :"): + domainName := strings.TrimSpace(strings.Split(line, ":")[1]) + dnsSettings.Domains = append(dnsSettings.Domains, domainName) + case line == "SearchDomains : {": + inSearchDomainsArray = true + continue + case line == "ServerAddresses : {": + inServerAddressesArray = true + continue + case line == "}": + inSearchDomainsArray = false + inServerAddressesArray = false + } + + if inSearchDomainsArray { + searchDomain := strings.Split(line, " : ")[1] + dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) + } else if inServerAddressesArray { + address := strings.Split(line, " : ")[1] + if ip := net.ParseIP(address); ip != nil && ip.To4() != nil { + dnsSettings.ServerIP = address + inServerAddressesArray = false // Stop reading after finding the first IPv4 address + } + } + } + + if err := scanner.Err(); err != nil { + return dnsSettings, err + } + + // default to 53 port + dnsSettings.ServerPort = 53 + + return dnsSettings, nil +} + func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { err := s.addDNSState(key, domains, ip, port, true) if err != nil { @@ -194,23 +292,6 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port return nil } -func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error { - primaryServiceKey, existingNameserver, err := s.getPrimaryService() - if err != nil || primaryServiceKey == "" { - return fmt.Errorf("couldn't find the primary service key: %w", err) - } - - err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver) - if err != nil { - return fmt.Errorf("add dns setup: %w", err) - } - - log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port) - s.primaryServiceID = primaryServiceKey - - return nil -} - func (s *systemConfigurator) getPrimaryService() (string, string, error) { line := buildCommandLine("show", globalIPv4State, "") stdinCommands := wrapCommand(line) @@ -239,19 +320,6 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) { return primaryService, router, nil } -func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error { - lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0)) - lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer) - lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) - addDomainCommand := buildCreateStateWithOperation(setupKey, lines) - stdinCommands := wrapCommand(addDomainCommand) - _, err := runSystemConfigCommand(stdinCommands) - if err != nil { - return fmt.Errorf("applying dns setup, error: %w", err) - } - return nil -} - func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { if err := s.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via scutil: %w", err) diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index b502bf5eb..b3baf2fa8 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -24,7 +24,7 @@ const ( probeTimeout = 2 * time.Second ) -const testRecord = "." +const testRecord = "com." type upstreamClient interface { exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) @@ -42,6 +42,7 @@ type upstreamResolverBase struct { upstreamServers []string disabled bool failsCount atomic.Int32 + successCount atomic.Int32 failsTillDeact int32 mutex sync.Mutex reactivatePeriod time.Duration @@ -124,6 +125,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } + u.successCount.Add(1) log.Tracef("took %s to query the upstream %s", t, upstream) err = w.WriteMsg(rm) @@ -172,6 +174,11 @@ func (u *upstreamResolverBase) probeAvailability() { default: } + // avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact + if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { + return + } + var success bool var mu sync.Mutex var wg sync.WaitGroup @@ -183,7 +190,7 @@ func (u *upstreamResolverBase) probeAvailability() { wg.Add(1) go func() { defer wg.Done() - err := u.testNameserver(upstream) + err := u.testNameserver(upstream, 500*time.Millisecond) if err != nil { errors = multierror.Append(errors, err) log.Warnf("probing upstream nameserver %s: %s", upstream, err) @@ -224,7 +231,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { } for _, upstream := range u.upstreamServers { - if err := u.testNameserver(upstream); err != nil { + if err := u.testNameserver(upstream, probeTimeout); err != nil { log.Tracef("upstream check for %s: %s", upstream, err) } else { // at least one upstream server is available, stop probing @@ -244,6 +251,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) u.failsCount.Store(0) + u.successCount.Add(1) u.reactivate() u.disabled = false } @@ -265,13 +273,14 @@ func (u *upstreamResolverBase) disable(err error) { } log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) + u.successCount.Store(0) u.deactivate(err) u.disabled = true go u.waitUntilResponse() } -func (u *upstreamResolverBase) testNameserver(server string) error { - ctx, cancel := context.WithTimeout(u.ctx, probeTimeout) +func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(u.ctx, timeout) defer cancel() r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go index 8d6ccd51b..29df7ea7f 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/monitor_bsd.go @@ -4,6 +4,7 @@ package networkmonitor import ( "context" + "errors" "fmt" "syscall" "unsafe" @@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca return fmt.Errorf("failed to open routing socket: %v", err) } defer func() { - if err := unix.Close(fd); err != nil { + err := unix.Close(fd) + if err != nil && !errors.Is(err, unix.EBADF) { log.Errorf("Network monitor: failed to close routing socket: %v", err) } }() + go func() { + <-ctx.Done() + err := unix.Close(fd) + if err != nil && !errors.Is(err, unix.EBADF) { + log.Debugf("Network monitor: closed routing socket") + } + }() + for { select { case <-ctx.Done(): @@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca buf := make([]byte, 2048) n, err := unix.Read(fd, buf) if err != nil { - log.Errorf("Network monitor: failed to read from routing socket: %v", err) + if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { + log.Errorf("Network monitor: failed to read from routing socket: %v", err) + } continue } if n < unix.SizeofRtMsghdr { diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go index e24bdd066..308b2aa45 100644 --- a/client/internal/networkmonitor/monitor_windows.go +++ b/client/internal/networkmonitor/monitor_windows.go @@ -99,6 +99,11 @@ func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []syste return false } + if isSoftInterface(nexthop.Intf.Name) { + log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name) + return false + } + unspec := getUnspecifiedPrefix(nexthop.IP) defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec) @@ -119,7 +124,7 @@ func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix { return netip.PrefixFrom(netip.IPv4Unspecified(), 0) } -func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) { +func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) { var defaultRoutes []string foundMatchingRoute := false @@ -128,7 +133,7 @@ func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []syst routeInfo := formatRouteInfo(r) defaultRoutes = append(defaultRoutes, routeInfo) - if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 { + if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 { foundMatchingRoute = true log.Debugf("network monitor: found matching default route: %s", routeInfo) } @@ -232,14 +237,18 @@ func stateFromInt(state uint8) string { } func compareIntf(a, b *net.Interface) int { - if a == nil && b == nil { + switch { + case a == nil && b == nil: return 0 - } - if a == nil { + case a == nil: return -1 - } - if b == nil { + case b == nil: return 1 + default: + return a.Index - b.Index } - return a.Index - b.Index +} + +func isSoftInterface(name string) bool { + return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo") } diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index adf3b6a2f..779a09c65 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -3,6 +3,7 @@ package routemanager import ( "context" "fmt" + "reflect" "time" "github.com/hashicorp/go-multierror" @@ -303,22 +304,33 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { }() } -func (c *clientNetwork) handleUpdate(update routesUpdate) { +func (c *clientNetwork) handleUpdate(update routesUpdate) bool { + isUpdateMapDifferent := false updateMap := make(map[route.ID]*route.Route) for _, r := range update.routes { updateMap[r.ID] = r } + if len(c.routes) != len(updateMap) { + isUpdateMapDifferent = true + } + for id, r := range c.routes { _, found := updateMap[id] if !found { close(c.routePeersNotifiers[r.Peer]) delete(c.routePeersNotifiers, r.Peer) + isUpdateMapDifferent = true + continue + } + if !reflect.DeepEqual(c.routes[id], updateMap[id]) { + isUpdateMapDifferent = true } } c.routes = updateMap + return isUpdateMapDifferent } // peersStateAndUpdateWatcher is the main point of reacting on client network routing events. @@ -345,13 +357,19 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { log.Debugf("Received a new client network route update for [%v]", c.handler) - c.handleUpdate(update) + // hash update somehow + isTrueRouteUpdate := c.handleUpdate(update) c.updateSerial = update.updateSerial - err := c.recalculateRouteAndUpdatePeerAndSystem() - if err != nil { - log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) + if isTrueRouteUpdate { + log.Debug("Client network update contains different routes, recalculating routes") + err := c.recalculateRouteAndUpdatePeerAndSystem() + if err != nil { + log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) + } + } else { + log.Debug("Route update is not different, skipping route recalculation") } c.startPeersStatusChangeWatcher() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 2e1378414..597eddd51 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -16,6 +16,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" @@ -50,7 +51,7 @@ type DefaultManager struct { statusRecorder *peer.Status wgInterface iface.IWGIface pubKey string - notifier *notifier + notifier *notifier.Notifier routeRefCounter *refcounter.RouteRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration @@ -65,7 +66,8 @@ func NewManager( initialRoutes []*route.Route, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) - sysOps := systemops.NewSysOps(wgInterface) + notifier := notifier.NewNotifier() + sysOps := systemops.NewSysOps(wgInterface, notifier) dm := &DefaultManager{ ctx: mCTX, @@ -77,7 +79,7 @@ func NewManager( statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, - notifier: newNotifier(), + notifier: notifier, } dm.routeRefCounter = refcounter.New( @@ -107,7 +109,7 @@ func NewManager( if runtime.GOOS == "android" { cr := dm.clientRoutes(initialRoutes) - dm.notifier.setInitialClientRoutes(cr) + dm.notifier.SetInitialClientRoutes(cr) } return dm } @@ -186,7 +188,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) m.updateClientNetworks(updateSerial, filteredClientRoutes) - m.notifier.onNewRoutes(filteredClientRoutes) + m.notifier.OnNewRoutes(filteredClientRoutes) if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) @@ -199,14 +201,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro } } -// SetRouteChangeListener set RouteListener for route change notifier +// SetRouteChangeListener set RouteListener for route change Notifier func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) { - m.notifier.setListener(listener) + m.notifier.SetListener(listener) } // InitialRouteRange return the list of initial routes. It used by mobile systems func (m *DefaultManager) InitialRouteRange() []string { - return m.notifier.getInitialRouteRanges() + return m.notifier.GetInitialRouteRanges() } // GetRouteSelector returns the route selector @@ -226,7 +228,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { networks = m.routeSelector.FilterSelected(networks) - m.notifier.onNewRoutes(networks) + m.notifier.OnNewRoutes(networks) m.stopObsoleteClients(networks) diff --git a/client/internal/routemanager/notifier.go b/client/internal/routemanager/notifier/notifier.go similarity index 67% rename from client/internal/routemanager/notifier.go rename to client/internal/routemanager/notifier/notifier.go index b606c79da..ebdd60323 100644 --- a/client/internal/routemanager/notifier.go +++ b/client/internal/routemanager/notifier/notifier.go @@ -1,6 +1,7 @@ -package routemanager +package notifier import ( + "net/netip" "runtime" "sort" "strings" @@ -10,7 +11,7 @@ import ( "github.com/netbirdio/netbird/route" ) -type notifier struct { +type Notifier struct { initialRouteRanges []string routeRanges []string @@ -18,17 +19,17 @@ type notifier struct { listenerMux sync.Mutex } -func newNotifier() *notifier { - return ¬ifier{} +func NewNotifier() *Notifier { + return &Notifier{} } -func (n *notifier) setListener(listener listener.NetworkChangeListener) { +func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { n.listenerMux.Lock() defer n.listenerMux.Unlock() n.listener = listener } -func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) { +func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { nets := make([]string, 0) for _, r := range clientRoutes { nets = append(nets, r.Network.String()) @@ -37,7 +38,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) { n.initialRouteRanges = nets } -func (n *notifier) onNewRoutes(idMap route.HAMap) { +func (n *Notifier) OnNewRoutes(idMap route.HAMap) { + if runtime.GOOS != "android" { + return + } newNets := make([]string, 0) for _, routes := range idMap { for _, r := range routes { @@ -62,7 +66,30 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) { n.notify() } -func (n *notifier) notify() { +func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { + newNets := make([]string, 0) + for _, prefix := range prefixes { + newNets = append(newNets, prefix.String()) + } + + sort.Strings(newNets) + switch runtime.GOOS { + case "android": + if !n.hasDiff(n.initialRouteRanges, newNets) { + return + } + default: + if !n.hasDiff(n.routeRanges, newNets) { + return + } + } + + n.routeRanges = newNets + + n.notify() +} + +func (n *Notifier) notify() { n.listenerMux.Lock() defer n.listenerMux.Unlock() if n.listener == nil { @@ -74,7 +101,7 @@ func (n *notifier) notify() { }(n.listener) } -func (n *notifier) hasDiff(a []string, b []string) bool { +func (n *Notifier) hasDiff(a []string, b []string) bool { if len(a) != len(b) { return true } @@ -86,7 +113,7 @@ func (n *notifier) hasDiff(a []string, b []string) bool { return false } -func (n *notifier) getInitialRouteRanges() []string { +func (n *Notifier) GetInitialRouteRanges() []string { return addIPv6RangeIfNeeded(n.initialRouteRanges) } diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index fa7ab0290..ae27b0123 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -3,7 +3,9 @@ package systemops import ( "net" "net/netip" + "sync" + "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/iface" ) @@ -18,10 +20,19 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop] type SysOps struct { refCounter *ExclusionCounter wgInterface iface.IWGIface + // prefixes is tracking all the current added prefixes im memory + // (this is used in iOS as all route updates require a full table update) + //nolint + prefixes map[netip.Prefix]struct{} + //nolint + mu sync.Mutex + // notifier is used to notify the system of route changes (also used on mobile) + notifier *notifier.Notifier } -func NewSysOps(wgInterface iface.IWGIface) *SysOps { +func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, + notifier: notifier, } } diff --git a/client/internal/routemanager/systemops/systemops_mobile.go b/client/internal/routemanager/systemops/systemops_android.go similarity index 96% rename from client/internal/routemanager/systemops/systemops_mobile.go rename to client/internal/routemanager/systemops/systemops_android.go index 43815c657..5e97a4a5f 100644 --- a/client/internal/routemanager/systemops/systemops_mobile.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -1,4 +1,4 @@ -//go:build ios || android +//go:build android package systemops diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index ce9a9082a..84b84483e 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -36,7 +36,7 @@ func TestConcurrentRoutes(t *testing.T) { baseIP := netip.MustParseAddr("192.0.2.0") intf := &net.Interface{Name: "lo0"} - r := NewSysOps(nil) + r := NewSysOps(nil, nil) var wg sync.WaitGroup for i := 0; i < 1024; i++ { diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 0c152d233..4190debf9 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -50,7 +50,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop) if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) { log.Tracef("Adding for prefix %s: %v", prefix, err) - // These errors are not critical but also we should not track and try to remove the routes either. + // These errors are not critical, but also we should not track and try to remove the routes either. return nexthop, refcounter.ErrIgnore } return nexthop, err @@ -135,6 +135,11 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIfac return Nexthop{}, vars.ErrRouteNotAllowed } + // Check if the prefix is part of any local subnets + if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal { + return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed) + } + // Determine the exit interface and next hop for the prefix, so we can add a specific route nexthop, err := GetNextHop(addr) if err != nil { @@ -167,6 +172,36 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIfac return exitNextHop, nil } +func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) { + localInterfaces, err := net.Interfaces() + if err != nil { + log.Errorf("Failed to get local interfaces: %v", err) + return false, nil + } + + for _, intf := range localInterfaces { + addrs, err := intf.Addrs() + if err != nil { + log.Errorf("Failed to get addresses for interface %s: %v", intf.Name, err) + continue + } + + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok { + log.Errorf("Failed to convert address to IPNet: %v", addr) + continue + } + + if ipnet.Contains(prefix.Addr().AsSlice()) { + return true, ipnet + } + } + } + + return false, nil +} + // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // in two /1 prefixes to avoid replacing the existing default route func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 292166582..94965c119 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -68,7 +68,7 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - r := NewSysOps(wgInterface) + r := NewSysOps(wgInterface, nil) _, _, err = r.SetupRouting(nil) require.NoError(t, err) @@ -224,7 +224,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { require.NoError(t, err, "InterfaceByName should not return err") intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} - r := NewSysOps(wgInterface) + r := NewSysOps(wgInterface, nil) // Prepare the environment if testCase.preExistingPrefix.IsValid() { @@ -379,7 +379,7 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, wgInterface.Close()) }) - r := NewSysOps(wgInterface) + r := NewSysOps(wgInterface, nil) _, _, err := r.SetupRouting(nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go new file mode 100644 index 000000000..7cfb2b298 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -0,0 +1,64 @@ +//go:build ios + +package systemops + +import ( + "net" + "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + r.mu.Lock() + defer r.mu.Unlock() + r.prefixes = make(map[netip.Prefix]struct{}) + return nil, nil, nil +} + +func (r *SysOps) CleanupRouting() error { + r.mu.Lock() + defer r.mu.Unlock() + + r.prefixes = make(map[netip.Prefix]struct{}) + r.notify() + return nil +} + +func (r *SysOps) AddVPNRoute(prefix netip.Prefix, _ *net.Interface) error { + r.mu.Lock() + defer r.mu.Unlock() + + r.prefixes[prefix] = struct{}{} + r.notify() + return nil +} + +func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.prefixes, prefix) + r.notify() + return nil +} + +func EnableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) { + return false, netip.Prefix{} +} + +func (r *SysOps) notify() { + prefixes := make([]netip.Prefix, 0, len(r.prefixes)) + for prefix := range r.prefixes { + prefixes = append(prefixes, prefix) + } + r.notifier.OnNewPrefixes(prefixes) +} diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go index 9180ed58c..19b006017 100644 --- a/client/internal/routemanager/systemops/systemops_windows_test.go +++ b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -73,7 +73,7 @@ var testCases = []testCase{ { name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence destination: "10.0.0.2:53", - expectedSourceIP: "10.0.0.1", + expectedSourceIP: "127.0.0.1", expectedDestPrefix: "10.0.0.0/8", expectedNextHop: "0.0.0.0", expectedInterface: "Loopback Pseudo-Interface 1", @@ -110,7 +110,7 @@ var testCases = []testCase{ { name: "To more specific route (local) without custom dialer via physical interface", destination: "127.0.10.2:53", - expectedSourceIP: "10.0.0.1", + expectedSourceIP: "127.0.0.1", expectedDestPrefix: "127.0.0.0/8", expectedNextHop: "0.0.0.0", expectedInterface: "Loopback Pseudo-Interface 1", @@ -181,31 +181,6 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut return combinedOutput } -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { - t.Helper() - - ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) - require.NoError(t, err) - subnetMaskSize, _ := ipNet.Mask.Size() - script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to assign IP address to loopback adapter") - - // Wait for the IP address to be applied - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - err = waitForIPAddress(ctx, interfaceName, ip.String()) - require.NoError(t, err, "IP address not applied within timeout") - - t.Cleanup(func() { - script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to remove IP address from loopback adapter") - }) - - return interfaceName -} - func fetchOriginalGateway() (*RouteInfo, error) { cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json") output, err := cmd.CombinedOutput() @@ -231,30 +206,6 @@ func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") } -func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() - if err != nil { - return err - } - - ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") - for _, ip := range ipAddresses { - if strings.TrimSpace(ip) == expectedIPAddress { - return nil - } - } - } - } -} - func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { var combined FindNetRouteOutput @@ -285,5 +236,25 @@ func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { func setupDummyInterfacesAndRoutes(t *testing.T) { t.Helper() - createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") + addDummyRoute(t, "10.0.0.0/8") +} + +func addDummyRoute(t *testing.T, dstCIDR string) { + t.Helper() + + script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR) + + output, err := exec.Command("powershell", "-Command", script).CombinedOutput() + if err != nil { + t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output) + t.FailNow() + } + + t.Cleanup(func() { + script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR) + output, err := exec.Command("powershell", "-Command", script).CombinedOutput() + if err != nil { + t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output) + } + }) } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index d96f035df..d80072c78 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) @@ -47,6 +48,7 @@ type CustomLogger interface { type selectRoute struct { NetID string Network netip.Prefix + Domains domain.List Selected bool } @@ -279,6 +281,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { route := &selectRoute{ NetID: string(id), Network: rt[0].Network, + Domains: rt[0].Domains, Selected: routeSelector.IsSelected(id), } routes = append(routes, route) @@ -299,17 +302,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { return iPrefix < jPrefix }) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + + return prepareRouteSelectionDetails(routes, resolvedDomains), nil + +} + +func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails { var routeSelection []RoutesSelectionInfo for _, r := range routes { + domainList := make([]DomainInfo, 0) + for _, d := range r.Domains { + domainResp := DomainInfo{ + Domain: d.SafeString(), + } + if prefixes, exists := resolvedDomains[d]; exists { + var ipStrings []string + for _, prefix := range prefixes { + ipStrings = append(ipStrings, prefix.Addr().String()) + } + domainResp.ResolvedIPs = strings.Join(ipStrings, ", ") + } + domainList = append(domainList, domainResp) + } + domainDetails := DomainDetails{items: domainList} routeSelection = append(routeSelection, RoutesSelectionInfo{ ID: r.NetID, Network: r.Network.String(), + Domains: &domainDetails, Selected: r.Selected, }) } routeSelectionDetails := RoutesSelectionDetails{items: routeSelection} - return &routeSelectionDetails, nil + return &routeSelectionDetails } func (c *Client) SelectRoute(id string) error { diff --git a/client/ios/NetBirdSDK/routes.go b/client/ios/NetBirdSDK/routes.go index 63536255b..30d0d0d0a 100644 --- a/client/ios/NetBirdSDK/routes.go +++ b/client/ios/NetBirdSDK/routes.go @@ -16,9 +16,25 @@ type RoutesSelectionDetails struct { type RoutesSelectionInfo struct { ID string Network string + Domains *DomainDetails Selected bool } +type DomainCollection interface { + Add(s DomainInfo) DomainCollection + Get(i int) *DomainInfo + Size() int +} + +type DomainDetails struct { + items []DomainInfo +} + +type DomainInfo struct { + Domain string + ResolvedIPs string +} + // Add new PeerInfo to the collection func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails { array.items = append(array.items, s) @@ -34,3 +50,16 @@ func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo { func (array RoutesSelectionDetails) Size() int { return len(array.items) } + +func (array DomainDetails) Add(s DomainInfo) DomainCollection { + array.items = append(array.items, s) + return array +} + +func (array DomainDetails) Get(i int) *DomainInfo { + return &array.items[i] +} + +func (array DomainDetails) Size() int { + return len(array.items) +} diff --git a/client/server/server.go b/client/server/server.go index 502d4168c..aa70f2404 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes } // Down engine work in the daemon. -func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { +func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo state := internal.CtxGetState(s.rootCtx) state.Set(internal.StatusIdle) - return &proto.DownResponse{}, nil + maxWaitTime := 5 * time.Second + timeout := time.After(maxWaitTime) + + engine := s.connectClient.Engine() + + for { + if !engine.IsWGIfaceUp() { + return &proto.DownResponse{}, nil + } + + select { + case <-ctx.Done(): + return &proto.DownResponse{}, nil + case <-timeout: + return nil, fmt.Errorf("failed to shut down properly") + default: + time.Sleep(100 * time.Millisecond) + } + } } // Status returns the daemon status diff --git a/client/system/info_linux.go b/client/system/info_linux.go index d85a6faec..db58d913f 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -8,6 +8,7 @@ import ( "context" "os" "os/exec" + "regexp" "runtime" "strings" "time" @@ -89,9 +90,17 @@ func _getInfo() string { func sysInfo() (serialNumber string, productName string, manufacturer string) { var si sysinfo.SysInfo si.GetSysInfo() + isascii := regexp.MustCompile("^[[:ascii:]]+$") serial := si.Chassis.Serial if (serial == "Default string" || serial == "") && si.Product.Serial != "" { serial = si.Product.Serial } - return serial, si.Product.Name, si.Product.Vendor + if (!isascii.MatchString(serial)) && si.Board.Serial != "" { + serial = si.Board.Serial + } + name := si.Product.Name + if (!isascii.MatchString(name)) && si.Board.Name != "" { + name = si.Board.Name + } + return serial, name, si.Product.Vendor } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index cadd14f18..58004dd4a 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -15,7 +15,6 @@ import ( "strconv" "strings" "sync" - "syscall" "time" "unicode" @@ -34,6 +33,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" ) @@ -62,8 +62,25 @@ func main() { var errorMSG string flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window") + tmpDir := "/tmp" + if runtime.GOOS == "windows" { + tmpDir = os.TempDir() + } + + var saveLogsInFile bool + flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir)) + flag.Parse() + if saveLogsInFile { + logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid())) + err := util.InitLog("trace", logFile) + if err != nil { + log.Errorf("error while initializing log: %v", err) + return + } + } + a := app.NewWithID("NetBird") a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG)) @@ -76,8 +93,12 @@ func main() { if showSettings || showRoutes { a.Run() } else { - if err := checkPIDFile(); err != nil { - log.Errorf("check PID file: %v", err) + running, err := isAnotherProcessRunning() + if err != nil { + log.Errorf("error while checking process: %v", err) + } + if running { + log.Warn("another process is running") return } client.setDefaultFonts() @@ -861,104 +882,3 @@ func openURL(url string) error { } return err } - -// checkPIDFile exists and return error, or write new. -func checkPIDFile() error { - pidFile := path.Join(os.TempDir(), "wiretrustee-ui.pid") - if piddata, err := os.ReadFile(pidFile); err == nil { - if pid, err := strconv.Atoi(string(piddata)); err == nil { - if process, err := os.FindProcess(pid); err == nil { - if err := process.Signal(syscall.Signal(0)); err == nil { - return fmt.Errorf("process already exists: %d", pid) - } - } - } - } - - return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec -} - -func (s *serviceClient) setDefaultFonts() { - var ( - defaultFontPath string - ) - - //TODO: Linux Multiple Language Support - switch runtime.GOOS { - case "darwin": - defaultFontPath = "/Library/Fonts/Arial Unicode.ttf" - case "windows": - fontPath := s.getWindowsFontFilePath() - defaultFontPath = fontPath - } - - _, err := os.Stat(defaultFontPath) - - if err == nil { - os.Setenv("FYNE_FONT", defaultFontPath) - } -} - -func (s *serviceClient) getWindowsFontFilePath() (fontPath string) { - /* - https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts - https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list - */ - - var ( - fontFolder string = "C:/Windows/Fonts" - fontMapping = map[string]string{ - "default": "Segoeui.ttf", - "zh-CN": "Msyh.ttc", - "am-ET": "Ebrima.ttf", - "nirmala": "Nirmala.ttf", - "chr-CHER-US": "Gadugi.ttf", - "zh-HK": "Msjh.ttc", - "zh-TW": "Msjh.ttc", - "ja-JP": "Yugothm.ttc", - "km-KH": "Leelawui.ttf", - "ko-KR": "Malgun.ttf", - "th-TH": "Leelawui.ttf", - "ti-ET": "Ebrima.ttf", - } - nirMalaLang = []string{ - "as-IN", - "bn-BD", - "bn-IN", - "gu-IN", - "hi-IN", - "kn-IN", - "kok-IN", - "ml-IN", - "mr-IN", - "ne-NP", - "or-IN", - "pa-IN", - "si-LK", - "ta-IN", - "te-IN", - } - ) - cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name") - output, err := cmd.Output() - if err != nil { - log.Errorf("Failed to get Windows default language setting: %v", err) - fontPath = path.Join(fontFolder, fontMapping["default"]) - return - } - defaultLanguage := strings.TrimSpace(string(output)) - - for _, lang := range nirMalaLang { - if defaultLanguage == lang { - fontPath = path.Join(fontFolder, fontMapping["nirmala"]) - return - } - } - - if font, ok := fontMapping[defaultLanguage]; ok { - fontPath = path.Join(fontFolder, font) - } else { - fontPath = path.Join(fontFolder, fontMapping["default"]) - } - return -} diff --git a/client/ui/font_bsd.go b/client/ui/font_bsd.go new file mode 100644 index 000000000..41bccceca --- /dev/null +++ b/client/ui/font_bsd.go @@ -0,0 +1,26 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd + +package main + +import ( + "os" + "runtime" + + log "github.com/sirupsen/logrus" +) + +const defaultFontPath = "/Library/Fonts/Arial Unicode.ttf" + +func (s *serviceClient) setDefaultFonts() { + // TODO: add other bsd paths + if runtime.GOOS != "darwin" { + return + } + + if _, err := os.Stat(defaultFontPath); err != nil { + log.Errorf("Failed to find default font file: %v", err) + return + } + + os.Setenv("FYNE_FONT", defaultFontPath) +} diff --git a/client/ui/font_linux.go b/client/ui/font_linux.go new file mode 100644 index 000000000..4aa92494a --- /dev/null +++ b/client/ui/font_linux.go @@ -0,0 +1,7 @@ +//go:build !386 + +package main + +func (s *serviceClient) setDefaultFonts() { + //TODO: Linux Multiple Language Support +} diff --git a/client/ui/font_windows.go b/client/ui/font_windows.go new file mode 100644 index 000000000..c37a5455f --- /dev/null +++ b/client/ui/font_windows.go @@ -0,0 +1,91 @@ +package main + +import ( + "os" + "path" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +func (s *serviceClient) setDefaultFonts() { + defaultFontPath := s.getWindowsFontFilePath() + + if _, err := os.Stat(defaultFontPath); err != nil { + log.Errorf("Failed to find default font file: %v", err) + return + } + + os.Setenv("FYNE_FONT", defaultFontPath) +} + +func (s *serviceClient) getWindowsFontFilePath() string { + var ( + fontFolder = "C:/Windows/Fonts" + fontMapping = map[string]string{ + "default": "Segoeui.ttf", + "zh-CN": "Msyh.ttc", + "am-ET": "Ebrima.ttf", + "nirmala": "Nirmala.ttf", + "chr-CHER-US": "Gadugi.ttf", + "zh-HK": "Msjh.ttc", + "zh-TW": "Msjh.ttc", + "ja-JP": "Yugothm.ttc", + "km-KH": "Leelawui.ttf", + "ko-KR": "Malgun.ttf", + "th-TH": "Leelawui.ttf", + "ti-ET": "Ebrima.ttf", + } + nirMalaLang = []string{ + "as-IN", + "bn-BD", + "bn-IN", + "gu-IN", + "hi-IN", + "kn-IN", + "kok-IN", + "ml-IN", + "mr-IN", + "ne-NP", + "or-IN", + "pa-IN", + "si-LK", + "ta-IN", + "te-IN", + } + ) + + // getUserDefaultLocaleName.Call() panics if the func is not found + defer func() { + if r := recover(); r != nil { + log.Errorf("Recovered from panic: %v", r) + } + }() + + kernel32 := windows.NewLazySystemDLL("kernel32.dll") + getUserDefaultLocaleName := kernel32.NewProc("GetUserDefaultLocaleName") + + buf := make([]uint16, 85) // LOCALE_NAME_MAX_LENGTH is usually 85 + r, _, err := getUserDefaultLocaleName.Call(uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf))) + // returns 0 on failure, err is always non-nil + // https://learn.microsoft.com/en-us/windows/win32/api/winnls/nf-winnls-getuserdefaultlocalename + if r == 0 { + log.Errorf("GetUserDefaultLocaleName call failed: %v", err) + return path.Join(fontFolder, fontMapping["default"]) + } + + defaultLanguage := windows.UTF16ToString(buf) + + for _, lang := range nirMalaLang { + if defaultLanguage == lang { + return path.Join(fontFolder, fontMapping["nirmala"]) + } + } + + if font, ok := fontMapping[defaultLanguage]; ok { + return path.Join(fontFolder, font) + } + + return path.Join(fontFolder, fontMapping["default"]) +} diff --git a/client/ui/process.go b/client/ui/process.go new file mode 100644 index 000000000..bcb3dd879 --- /dev/null +++ b/client/ui/process.go @@ -0,0 +1,37 @@ +package main + +import ( + "os" + "path/filepath" + "strings" + + "github.com/shirou/gopsutil/v3/process" +) + +func isAnotherProcessRunning() (bool, error) { + processes, err := process.Processes() + if err != nil { + return false, err + } + + pid := os.Getpid() + processName := strings.ToLower(filepath.Base(os.Args[0])) + + for _, p := range processes { + if int(p.Pid) == pid { + continue + } + + runningProcessPath, err := p.Exe() + // most errors are related to short-lived processes + if err != nil { + continue + } + + if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) { + return true, nil + } + } + + return false, nil +} diff --git a/client/ui/process_nonwindows.go b/client/ui/process_nonwindows.go new file mode 100644 index 000000000..0d17be2be --- /dev/null +++ b/client/ui/process_nonwindows.go @@ -0,0 +1,26 @@ +//go:build !windows + +package main + +import ( + "os" + + "github.com/shirou/gopsutil/v3/process" + log "github.com/sirupsen/logrus" +) + +func isProcessOwnedByCurrentUser(p *process.Process) bool { + currentUserID := os.Getuid() + uids, err := p.Uids() + if err != nil { + log.Errorf("get process uids: %v", err) + return false + } + for _, id := range uids { + log.Debugf("checking process uid: %d", id) + if int(id) == currentUserID { + return true + } + } + return false +} diff --git a/client/ui/process_windows.go b/client/ui/process_windows.go new file mode 100644 index 000000000..b15b0ed24 --- /dev/null +++ b/client/ui/process_windows.go @@ -0,0 +1,24 @@ +package main + +import ( + "os/user" + + "github.com/shirou/gopsutil/v3/process" + log "github.com/sirupsen/logrus" +) + +func isProcessOwnedByCurrentUser(p *process.Process) bool { + processUsername, err := p.Username() + if err != nil { + log.Errorf("get process username error: %v", err) + return false + } + + currUser, err := user.Current() + if err != nil { + log.Errorf("get current user error: %v", err) + return false + } + + return processUsername == currUser.Username +} diff --git a/go.mod b/go.mod index 94cf5000d..251eaf777 100644 --- a/go.mod +++ b/go.mod @@ -19,12 +19,12 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.23.0 - golang.org/x/sys v0.20.0 + golang.org/x/crypto v0.24.0 + golang.org/x/sys v0.21.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.64.0 + google.golang.org/grpc v1.64.1 google.golang.org/protobuf v1.34.1 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -85,10 +85,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 - golang.org/x/net v0.25.0 + golang.org/x/net v0.26.0 golang.org/x/oauth2 v0.19.0 golang.org/x/sync v0.7.0 - golang.org/x/term v0.20.0 + golang.org/x/term v0.21.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.7 diff --git a/go.sum b/go.sum index b0d239035..9f6ae4a76 100644 --- a/go.sum +++ b/go.sum @@ -545,8 +545,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= @@ -592,8 +592,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg= @@ -650,8 +650,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -659,8 +659,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= +golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -719,8 +719,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= -google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= +google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/iface/tun_android.go b/iface/tun_android.go index dc6abea36..504993094 100644 --- a/iface/tun_android.go +++ b/iface/tun_android.go @@ -64,7 +64,7 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string t.wrapper = newDeviceWrapper(tunDevice) log.Debugf("attaching to interface %v", name) - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) + t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() diff --git a/iface/tun_darwin.go b/iface/tun_darwin.go index 7d684f52e..364e5dfad 100644 --- a/iface/tun_darwin.go +++ b/iface/tun_darwin.go @@ -49,7 +49,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) { t.device = device.NewDevice( t.wrapper, t.iceBind, - device.NewLogger(device.LogLevelSilent, "[netbird] "), + device.NewLogger(wgLogLevel(), "[netbird] "), ) err = t.assignAddr() diff --git a/iface/tun_ios.go b/iface/tun_ios.go index 83e26e08d..6d53cc333 100644 --- a/iface/tun_ios.go +++ b/iface/tun_ios.go @@ -64,7 +64,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) { t.wrapper = newDeviceWrapper(tunDevice) log.Debug("Attaching to interface") - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) + t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() diff --git a/iface/tun_netstack.go b/iface/tun_netstack.go index beb3acc3f..df2f75c45 100644 --- a/iface/tun_netstack.go +++ b/iface/tun_netstack.go @@ -54,7 +54,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) { t.device = device.NewDevice( t.wrapper, t.iceBind, - device.NewLogger(device.LogLevelSilent, "[netbird] "), + device.NewLogger(wgLogLevel(), "[netbird] "), ) t.configurer = newWGUSPConfigurer(t.device, t.name) diff --git a/iface/tun_usp_unix.go b/iface/tun_usp_unix.go index b18794b25..814c9ca89 100644 --- a/iface/tun_usp_unix.go +++ b/iface/tun_usp_unix.go @@ -57,7 +57,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) { t.device = device.NewDevice( t.wrapper, t.iceBind, - device.NewLogger(device.LogLevelSilent, "[netbird] "), + device.NewLogger(wgLogLevel(), "[netbird] "), ) err = t.assignAddr() diff --git a/iface/tun_windows.go b/iface/tun_windows.go index 5c77f1d16..0d658059f 100644 --- a/iface/tun_windows.go +++ b/iface/tun_windows.go @@ -41,6 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int, } func (t *tunDevice) Create() (wgConfigurer, error) { + log.Info("create tun interface") tunDevice, err := tun.CreateTUN(t.name, t.mtu) if err != nil { return nil, err @@ -52,7 +53,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) { t.device = device.NewDevice( t.wrapper, t.iceBind, - device.NewLogger(device.LogLevelSilent, "[netbird] "), + device.NewLogger(wgLogLevel(), "[netbird] "), ) luid := winipcfg.LUID(t.nativeTunDevice.LUID()) diff --git a/iface/wg_log.go b/iface/wg_log.go new file mode 100644 index 000000000..b44f6fc0b --- /dev/null +++ b/iface/wg_log.go @@ -0,0 +1,15 @@ +package iface + +import ( + "os" + + "golang.zx2c4.com/wireguard/device" +) + +func wgLogLevel() int { + if os.Getenv("NB_WG_DEBUG") == "true" { + return device.LogLevelVerbose + } else { + return device.LogLevelSilent + } +} diff --git a/management/client/grpc.go b/management/client/grpc.go index a8f4a91c7..eaadcd317 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -2,7 +2,6 @@ package client import ( "context" - "crypto/tls" "fmt" "io" "sync" @@ -11,15 +10,11 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" - - "github.com/cenkalti/backoff/v4" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" @@ -51,26 +46,21 @@ type GrpcClient struct { // NewClient creates a new client to Management service func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { - transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) + var conn *grpc.ClientConn - if tlsEnabled { - transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + operation := func() error { + var err error + conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + if err != nil { + log.Printf("createConnection error: %v", err) + return err + } + return nil } - mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) - defer cancel() - conn, err := grpc.DialContext( - mgmCtx, - addr, - transportOption, - nbgrpc.WithCustomDialer(), - grpc.WithBlock(), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 30 * time.Second, - Timeout: 10 * time.Second, - })) + err := backoff.Retry(operation, nbgrpc.Backoff(ctx)) if err != nil { - log.Errorf("failed creating connection to Management Service %v", err) + log.Errorf("failed creating connection to Management Service: %v", err) return nil, err } @@ -326,25 +316,44 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro if !c.ready() { return nil, fmt.Errorf(errMsgNoMgmtConnection) } + loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) if err != nil { log.Errorf("failed to encrypt message: %s", err) return nil, err } - mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout) - defer cancel() - resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ - WgPubKey: c.key.PublicKey().String(), - Body: loginReq, - }) + + var resp *proto.EncryptedMessage + operation := func() error { + mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout) + defer cancel() + + var err error + resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: loginReq, + }) + if err != nil { + // retry only on context canceled + if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled { + return err + } + return backoff.Permanent(err) + } + + return nil + } + + err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx)) if err != nil { + log.Errorf("failed to login to Management Service: %v", err) return nil, err } loginResp := &proto.LoginResponse{} err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) if err != nil { - log.Errorf("failed to decrypt registration message: %s", err) + log.Errorf("failed to decrypt login response: %s", err) return nil, err } diff --git a/management/server/account.go b/management/server/account.go index 27c21e402..558de6fbb 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -69,6 +69,7 @@ type AccountManager interface { ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) + SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) @@ -95,6 +96,7 @@ type AccountManager interface { GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error + SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error diff --git a/management/server/file_store.go b/management/server/file_store.go index 3fd543797..c649602e2 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -746,3 +746,11 @@ func (s *FileStore) Close(ctx context.Context) error { func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } + +func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error { + return status.Errorf(status.Internal, "SaveUsers is not implemented") +} + +func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { + return status.Errorf(status.Internal, "SaveGroups is not implemented") +} diff --git a/management/server/group.go b/management/server/group.go index ea512924b..45c51bda2 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -112,61 +112,85 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() + return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) +} +// SaveGroups adds new groups to the account. +// Note: This function does not acquire the global lock. +// It is the caller's responsibility to ensure proper locking is in place before invoking this method. +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { - return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) - } + var eventsToStore []func() - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { + for _, newGroup := range newGroups { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { + return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) + } - existingGroup, err := account.FindGroupByName(newGroup.Name) - if err != nil { - s, ok := status.FromError(err) - if !ok || s.ErrorType != status.NotFound { - return err + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { + existingGroup, err := account.FindGroupByName(newGroup.Name) + if err != nil { + s, ok := status.FromError(err) + if !ok || s.ErrorType != status.NotFound { + return err + } + } + + // Avoid duplicate groups only for the API issued groups. + // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. + if existingGroup != nil { + return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) + } + + newGroup.ID = xid.New().String() + } + + for _, peerID := range newGroup.Peers { + if account.Peers[peerID] == nil { + return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } } - // avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are - // coming from the IdP that we don't have control of. - if existingGroup != nil { - return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) - } + oldGroup := account.Groups[newGroup.ID] + account.Groups[newGroup.ID] = newGroup - newGroup.ID = xid.New().String() + events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) + eventsToStore = append(eventsToStore, events...) } - for _, peerID := range newGroup.Peers { - if account.Peers[peerID] == nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } - } - - oldGroup, exists := account.Groups[newGroup.ID] - account.Groups[newGroup.ID] = newGroup - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil { return err } am.updateAccountPeers(ctx, account) - // the following snippet tracks the activity and stores the group events in the event store. - // It has to happen after all the operations have been successfully performed. + for _, storeEvent := range eventsToStore { + storeEvent() + } + + return nil +} + +// prepareGroupEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { + var eventsToStore []func() + addedPeers := make([]string, 0) removedPeers := make([]string, 0) - if exists { + + if oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { addedPeers = append(addedPeers, newGroup.Peers...) - am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) + }) } for _, p := range addedPeers { @@ -175,11 +199,14 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) continue } - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), - "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - }) + peerCopy := peer // copy to avoid closure issues + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, + map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), + "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), + }) + }) } for _, p := range removedPeers { @@ -188,14 +215,17 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) continue } - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), - "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - }) + peerCopy := peer // copy to avoid closure issues + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, + map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), + "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), + }) + }) } - return nil + return eventsToStore } // difference returns the elements in `a` that aren't in `b`. diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 177088ac5..25bcdfcee 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -40,6 +40,7 @@ type MockAccountManager struct { GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error @@ -64,6 +65,7 @@ type MockAccountManager struct { ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) + SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error @@ -308,6 +310,14 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") } +// SaveGroups mock implementation of SaveGroups from server.AccountManager interface +func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error { + if am.SaveGroupsFunc != nil { + return am.SaveGroupsFunc(ctx, accountID, userID, groups) + } + return status.Errorf(codes.Unimplemented, "method SaveGroups is not implemented") +} + // DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { if am.DeleteGroupFunc != nil { @@ -502,6 +512,14 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented") } +// SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface +func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) { + if am.SaveOrAddUsersFunc != nil { + return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists) + } + return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUsers is not implemented") +} + // DeleteUser mocks DeleteUser of the AccountManager interface func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { if am.DeleteUserFunc != nil { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index b5ae82828..37cc10d8b 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -274,10 +274,15 @@ func (s *SqlStore) GetInstallationID() string { func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus - result := s.db.Model(&nbpeer.Peer{}). - Where("account_id = ? AND id = ?", accountID, peerID). - Updates(peerCopy) + fieldsToUpdate := []string{ + "peer_status_last_seen", "peer_status_connected", + "peer_status_login_expired", "peer_status_required_approval", + } + result := s.db.Model(&nbpeer.Peer{}). + Select(fieldsToUpdate). + Where("account_id = ? AND id = ?", accountID, peerID). + Updates(&peerCopy) if result.Error != nil { return result.Error } @@ -311,6 +316,34 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P return nil } +// SaveUsers saves the given list of users to the database. +// It updates existing users if a conflict occurs. +func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { + usersToSave := make([]User, 0, len(users)) + for _, user := range users { + user.AccountID = accountID + for id, pat := range user.PATs { + pat.ID = id + user.PATsG = append(user.PATsG, *pat) + } + usersToSave = append(usersToSave, *user) + } + return s.db.Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.OnConflict{UpdateAll: true}). + Create(&usersToSave).Error +} + +// SaveGroups saves the given list of groups to the database. +// It updates existing groups if a conflict occurs. +func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { + groupsToSave := make([]nbgroup.Group, 0, len(groups)) + for _, group := range groups { + group.AccountID = accountID + groupsToSave = append(groupsToSave, *group) + } + return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error +} + // DeleteHashedPAT2TokenIDIndex is noop in SqlStore func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { return nil @@ -662,11 +695,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe } file := filepath.Join(dataDir, storeStr) - db, err := gorm.Open(sqlite.Open(file), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), - CreateBatchSize: 400, - PrepareStmt: true, - }) + db, err := gorm.Open(sqlite.Open(file), getGormConfig()) if err != nil { return nil, err } @@ -676,10 +705,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe // NewPostgresqlStore creates a new Postgres store. func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), - PrepareStmt: true, - }) + db, err := gorm.Open(postgres.Open(dsn), getGormConfig()) if err != nil { return nil, err } @@ -687,6 +713,14 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe return NewSqlStore(ctx, db, PostgresStoreEngine, metrics) } +func getGormConfig() *gorm.Config { + return &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + CreateBatchSize: 400, + PrepareStmt: true, + } +} + // newPostgresStore initializes a new Postgres store. func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) { dsn, ok := os.LookupEnv(postgresDsnEnv) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index e0e893052..f46ca7e5d 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -41,11 +41,22 @@ func TestSqlite_NewStore(t *testing.T) { } func TestSqlite_SaveAccount_Large(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" { + t.Skip("skip large test on non-linux OS due to environment restrictions") } + t.Run("SQLite", func(t *testing.T) { + store := newSqliteStore(t) + runLargeTest(t, store) + }) + // create store outside to have a better time counter for the test + store := newPostgresqlStore(t) + t.Run("PostgreSQL", func(t *testing.T) { + runLargeTest(t, store) + }) +} - store := newSqliteStore(t) +func runLargeTest(t *testing.T, store Store) { + t.Helper() account := newAccountWithId(context.Background(), "account_id", "testuser", "") groupALL, err := account.GetGroupAll() @@ -54,7 +65,7 @@ func TestSqlite_SaveAccount_Large(t *testing.T) { } setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey - const numPerAccount = 2000 + const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { netIP := randomIPv4() peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) @@ -362,7 +373,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { require.NoError(t, err) // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} + newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) @@ -377,7 +388,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) diff --git a/management/server/store.go b/management/server/store.go index 05a09b3ee..3ba73e8c7 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,6 +12,7 @@ import ( "strings" "time" + nbgroup "github.com/netbirdio/netbird/management/server/group" log "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -41,6 +42,8 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) SaveAccount(ctx context.Context, account *Account) error + SaveUsers(accountID string, users map[string]*User) error + SaveGroups(accountID string, groups map[string]*nbgroup.Group) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error GetInstallationID() string diff --git a/management/server/user.go b/management/server/user.go index 302cfccaa..65b5c7878 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -740,7 +740,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return pats, nil } -// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error. +// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) { return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound } @@ -748,11 +748,31 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { + if update == nil { + return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") + } + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - if update == nil { - return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") + updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists) + if err != nil { + return nil, err + } + + if len(updatedUsers) == 0 { + return nil, status.Errorf(status.Internal, "user was not updated") + } + + return updatedUsers[0], nil +} + +// SaveOrAddUsers updates existing users or adds new users to the account. +// Note: This function does not acquire the global lock. +// It is the caller's responsibility to ensure proper locking is in place before invoking this method. +func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) { + if len(updates) == 0 { + return nil, nil //nolint:nilnil } account, err := am.Store.GetAccount(ctx, accountID) @@ -769,144 +789,200 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations") } - oldUser := account.Users[update.Id] - if oldUser == nil { - if !addIfNotExists { - return nil, status.Errorf(status.NotFound, "user to update doesn't exist") + updatedUsers := make([]*UserInfo, 0, len(updates)) + var ( + expiredPeers []*nbpeer.Peer + eventsToStore []func() + ) + + for _, update := range updates { + if update == nil { + return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - // when addIfNotExists is set to true the newUser will use all fields from the update input - oldUser = update - } - if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked { - return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") - } - - if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && update.Role != initiatorUser.Role { - return nil, status.Errorf(status.PermissionDenied, "admins can't change their role") - } - - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role { - return nil, status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user") - } - - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { - return nil, status.Errorf(status.PermissionDenied, "unable to block owner user") - } - - if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role { - return nil, status.Errorf(status.PermissionDenied, "only owners can add owner role to other users") - } - - if oldUser.IsServiceUser && update.Role == UserRoleOwner { - return nil, status.Errorf(status.PermissionDenied, "can't update a service user with owner role") - } - - transferedOwnerRole := false - if initiatorUser.Role == UserRoleOwner && initiatorUserID != update.Id && update.Role == UserRoleOwner { - newInitiatorUser := initiatorUser.Copy() - newInitiatorUser.Role = UserRoleAdmin - account.Users[initiatorUserID] = newInitiatorUser - transferedOwnerRole = true - } - - // only auto groups, revoked status, and integration reference can be updated for now - newUser := oldUser.Copy() - newUser.Role = update.Role - newUser.Blocked = update.Blocked - // these two fields can't be set via API, only via direct call to the method - newUser.Issued = update.Issued - newUser.IntegrationReference = update.IntegrationReference - - for _, newGroupID := range update.AutoGroups { - if _, ok := account.Groups[newGroupID]; !ok { - return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", - newGroupID, update.Id) + oldUser := account.Users[update.Id] + if oldUser == nil { + if !addIfNotExists { + return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) + } + // when addIfNotExists is set to true, the newUser will use all fields from the update input + oldUser = update } - } - newUser.AutoGroups = update.AutoGroups - account.Users[newUser.Id] = newUser - - if !oldUser.IsBlocked() && update.IsBlocked() { - // expire peers that belong to the user who's getting blocked - blockedPeers, err := account.FindUserPeers(update.Id) - if err != nil { + if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil { return nil, err } - if err := am.expireAndUpdatePeers(ctx, account, blockedPeers); err != nil { + // only auto groups, revoked status, and integration reference can be updated for now + newUser := oldUser.Copy() + newUser.Role = update.Role + newUser.Blocked = update.Blocked + newUser.AutoGroups = update.AutoGroups + // these two fields can't be set via API, only via direct call to the method + newUser.Issued = update.Issued + newUser.IntegrationReference = update.IntegrationReference + + transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update) + account.Users[newUser.Id] = newUser + + if !oldUser.IsBlocked() && update.IsBlocked() { + // expire peers that belong to the user who's getting blocked + blockedPeers, err := account.FindUserPeers(update.Id) + if err != nil { + return nil, err + } + expiredPeers = append(expiredPeers, blockedPeers...) + } + + if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { + removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) + // need force update all auto groups in any case they will not be duplicated + account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) + account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) + } + + events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) + eventsToStore = append(eventsToStore, events...) + + updatedUserInfo, err := getUserInfo(ctx, am, newUser, account) + if err != nil { + return nil, err + } + updatedUsers = append(updatedUsers, updatedUserInfo) + } + + if len(expiredPeers) > 0 { + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } } - if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { - removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) - // need force update all auto groups in any case they will not be duplicated - account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) - account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err - } + account.Network.IncSerial() + if err = am.Store.SaveUsers(account.Id, account.Users); err != nil { + return nil, err + } + if account.Settings.GroupsPropagationEnabled { am.updateAccountPeers(ctx, account) - } else { - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + return updatedUsers, nil +} + +// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() { + var eventsToStore []func() + + if oldUser.IsBlocked() != newUser.IsBlocked() { + if newUser.IsBlocked() { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil) + }) + } else { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil) + }) } } - defer func() { - if oldUser.IsBlocked() != update.IsBlocked() { - if update.IsBlocked() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) + switch { + case transferredOwnerRole: + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil) + }) + case oldUser.Role != newUser.Role: + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) + }) + } + + if newUser.AutoGroups != nil { + removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) + addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) + for _, g := range removedGroups { + group := account.GetGroup(g) + if group != nil { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + }) + } else { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) + log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) } } - - switch { - case transferedOwnerRole: - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) - case oldUser.Role != newUser.Role: - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) - default: - } - - if update.AutoGroups != nil { - removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) - addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) - for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, + for _, g := range addedGroups { + group := account.GetGroup(g) + if group != nil { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) - } - } - - for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - } + }) } } - }() + } - if !isNil(am.idpManager) && !newUser.IsServiceUser { - userData, err := am.lookupUserInCache(ctx, newUser.Id, account) + return eventsToStore +} + +func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool { + if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner { + newInitiatorUser := initiatorUser.Copy() + newInitiatorUser.Role = UserRoleAdmin + account.Users[initiatorUser.Id] = newInitiatorUser + return true + } + return false +} + +// getUserInfo retrieves the UserInfo for a given User and Account. +// If the AccountManager has a non-nil idpManager and the User is not a service user, +// it will attempt to look up the UserData from the cache. +func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) { + if !isNil(am.idpManager) && !user.IsServiceUser { + userData, err := am.lookupUserInCache(ctx, user.Id, account) if err != nil { return nil, err } - return newUser.ToUserInfo(userData, account.Settings) + return user.ToUserInfo(userData, account.Settings) } - return newUser.ToUserInfo(nil, account.Settings) + return user.ToUserInfo(nil, account.Settings) +} + +// validateUserUpdate validates the update operation for a user. +func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error { + if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { + return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") + } + if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role { + return status.Errorf(status.PermissionDenied, "admins can't change their role") + } + if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role { + return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user") + } + if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { + return status.Errorf(status.PermissionDenied, "unable to block owner user") + } + if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role { + return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users") + } + if oldUser.IsServiceUser && update.Role == UserRoleOwner { + return status.Errorf(status.PermissionDenied, "can't update a service user with owner role") + } + + for _, newGroupID := range update.AutoGroups { + if _, ok := account.Groups[newGroupID]; !ok { + return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", + newGroupID, update.Id) + } + } + + return nil } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist @@ -937,7 +1013,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u userObj := account.Users[userID] - if account.Domain != lowerDomain && userObj.Role == UserRoleOwner { + if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner { account.Domain = lowerDomain err = am.Store.SaveAccount(ctx, account) if err != nil { diff --git a/signal/README.md b/signal/README.md index dd2d761ad..9e3207cfa 100644 --- a/signal/README.md +++ b/signal/README.md @@ -18,6 +18,8 @@ Flags: --letsencrypt-domain string a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS --port int Server port to listen on (e.g. 10000) (default 10000) --ssl-dir string server ssl directory location. *Required only for Let's Encrypt certificates. (default "/var/lib/netbird/") + --cert-file string Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect + --cert-key string Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect Global Flags: --log-file string sets Netbird log path. If console is specified the the log will be output to stdout (default "/var/log/netbird/signal.log") @@ -90,6 +92,9 @@ The Signal Server exposes the following metrics in Prometheus format: - **registration_delay_milliseconds**: A Histogram metric that measures the time it took to register a peer in milliseconds. +- **get_registration_delay_milliseconds**: A Histogram metric that measures the time + it took to get a peer registration in + milliseconds. - **messages_forwarded_total**: A Counter metric that counts the total number of messages forwarded between peers. - **message_forward_failures_total**: A Counter metric that counts the total diff --git a/signal/client/grpc.go b/signal/client/grpc.go index c6f03ec86..7a3b502ff 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -2,7 +2,6 @@ package client import ( "context" - "crypto/tls" "fmt" "io" "sync" @@ -14,9 +13,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -64,28 +60,21 @@ func (c *GrpcClient) Close() error { // NewClient creates a new Signal client func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { + var conn *grpc.ClientConn - transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) - - if tlsEnabled { - transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + operation := func() error { + var err error + conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + if err != nil { + log.Printf("createConnection error: %v", err) + return err + } + return nil } - sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout) - defer cancel() - conn, err := grpc.DialContext( - sigCtx, - addr, - transportOption, - nbgrpc.WithCustomDialer(), - grpc.WithBlock(), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 30 * time.Second, - Timeout: 10 * time.Second, - })) - + err := backoff.Retry(operation, nbgrpc.Backoff(ctx)) if err != nil { - log.Errorf("failed to connect to the signalling server %v", err) + log.Errorf("failed to connect to the signalling server: %v", err) return nil, err } @@ -408,7 +397,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient, if err != nil { log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) - //todo send something?? + // todo send something?? } } } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 4b0dc583e..61f7a32a7 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -2,15 +2,12 @@ package cmd import ( "context" + "crypto/tls" "errors" "flag" "fmt" - "io" - "io/fs" "net" "net/http" - "os" - "path" "strings" "time" @@ -41,7 +38,8 @@ var ( signalLetsencryptDomain string signalSSLDir string defaultSignalSSLDir string - tlsEnabled bool + signalCertFile string + signalCertKey string signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, @@ -56,12 +54,22 @@ var ( }) runCmd = &cobra.Command{ - Use: "run", - Short: "start NetBird Signal Server daemon", + Use: "run", + Short: "start NetBird Signal Server daemon", + SilenceUsage: true, PreRun: func(cmd *cobra.Command, args []string) { + err := util.InitLog(logLevel, logFile) + if err != nil { + log.Fatalf("failed initializing log %v", err) + } + + flag.Parse() + // detect whether user specified a port userPort := cmd.Flag("port").Changed - if signalLetsencryptDomain != "" { + + tlsEnabled := false + if signalLetsencryptDomain != "" || (signalCertFile != "" && signalCertKey != "") { tlsEnabled = true } @@ -77,33 +85,12 @@ var ( RunE: func(cmd *cobra.Command, args []string) error { flag.Parse() - err := util.InitLog(logLevel, logFile) + opts, certManager, err := getTLSConfigurations() if err != nil { - log.Fatalf("failed initializing log %v", err) + return err } - if signalSSLDir == "" { - oldPath := "/var/lib/wiretrustee" - if migrateToNetbird(oldPath, defaultSignalSSLDir) { - if err := cpDir(oldPath, defaultSignalSSLDir); err != nil { - log.Fatal(err) - } - } - } - - var opts []grpc.ServerOption - var certManager *autocert.Manager - if tlsEnabled { - // Let's encrypt enabled -> generate certificate automatically - certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) - if err != nil { - return err - } - transportCredentials := credentials.NewTLS(certManager.TLSConfig()) - opts = append(opts, grpc.Creds(transportCredentials)) - } - - metricsServer := metrics.NewServer(metricsPort, "") + metricsServer, err := metrics.NewServer(metricsPort, "") if err != nil { return fmt.Errorf("setup metrics: %v", err) } @@ -124,7 +111,25 @@ var ( } proto.RegisterSignalExchangeServer(grpcServer, srv) + grpcRootHandler := grpcHandlerFunc(grpcServer) + + if certManager != nil { + startServerWithCertManager(certManager, grpcRootHandler) + } + var compatListener net.Listener + var grpcListener net.Listener + var httpListener net.Listener + + // If certManager is configured and signalPort == 443, then the gRPC server has already been started + if certManager == nil || signalPort != 443 { + grpcListener, err = serveGRPC(grpcServer, signalPort) + if err != nil { + return err + } + log.Infof("running gRPC server: %s", grpcListener.Addr().String()) + } + if signalPort != 10000 { // The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal // are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000. @@ -135,28 +140,6 @@ var ( log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } - var grpcListener net.Listener - var httpListener net.Listener - if tlsEnabled { - httpListener = certManager.Listener() - if signalPort == 443 { - // running gRPC and HTTP cert manager on the same port - serveHTTP(httpListener, certManager.HTTPHandler(grpcHandlerFunc(grpcServer))) - log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String()) - } else { - serveHTTP(httpListener, certManager.HTTPHandler(nil)) - log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String()) - } - } - - if signalPort != 443 || !tlsEnabled { - grpcListener, err = serveGRPC(grpcServer, signalPort) - if err != nil { - return err - } - log.Infof("running gRPC server: %s", grpcListener.Addr().String()) - } - log.Infof("signal server version %s", version.NetbirdVersion()) log.Infof("started Signal Service") @@ -190,6 +173,58 @@ var ( } ) +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { + var ( + err error + certManager *autocert.Manager + tlsConfig *tls.Config + ) + + if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { + log.Infof("running without TLS") + return nil, nil, nil + } + + if signalLetsencryptDomain != "" { + certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) + if err != nil { + return nil, certManager, err + } + tlsConfig = certManager.TLSConfig() + log.Infof("setting up TLS with LetsEncrypt.") + } else { + if signalCertFile == "" || signalCertKey == "" { + log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + } + + tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) + if err != nil { + log.Errorf("cannot load TLS credentials: %v", err) + return nil, certManager, err + } + log.Infof("setting up TLS with custom certificates.") + } + + transportCredentials := credentials.NewTLS(tlsConfig) + + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err +} + +func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) { + // a call to certManager.Listener() always creates a new listener so we do it once + httpListener := certManager.Listener() + if signalPort == 443 { + // running gRPC and HTTP cert manager on the same port + serveHTTP(httpListener, certManager.HTTPHandler(grpcRootHandler)) + log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String()) + } else { + // Start the HTTP cert manager server separately + serveHTTP(httpListener, certManager.HTTPHandler(nil)) + log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String()) + } +} + func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") || @@ -232,95 +267,29 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) { return listener, nil } -func cpFile(src, dst string) error { - var err error - var srcfd *os.File - var dstfd *os.File - var srcinfo os.FileInfo - - if srcfd, err = os.Open(src); err != nil { - return err - } - defer srcfd.Close() - - if dstfd, err = os.Create(dst); err != nil { - return err - } - defer dstfd.Close() - - if _, err = io.Copy(dstfd, srcfd); err != nil { - return err - } - if srcinfo, err = os.Stat(src); err != nil { - return err - } - return os.Chmod(dst, srcinfo.Mode()) -} - -func copySymLink(source, dest string) error { - link, err := os.Readlink(source) +func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) { + // Load server's certificate and private key + serverCert, err := tls.LoadX509KeyPair(certFile, certKey) if err != nil { - return err - } - return os.Symlink(link, dest) -} - -func cpDir(src string, dst string) error { - var err error - var fds []os.DirEntry - var srcinfo os.FileInfo - - if srcinfo, err = os.Stat(src); err != nil { - return err + return nil, err } - if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil { - return err + // NewDefaultAppMetrics the credentials and return it + config := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.NoClientCert, + NextProtos: []string{ + "h2", "http/1.1", // enable HTTP/2 + }, } - if fds, err = os.ReadDir(src); err != nil { - return err - } - for _, fd := range fds { - srcfp := path.Join(src, fd.Name()) - dstfp := path.Join(dst, fd.Name()) - - fileInfo, err := os.Stat(srcfp) - if err != nil { - log.Fatalf("Couldn't get fileInfo; %v", err) - } - - switch fileInfo.Mode() & os.ModeType { - case os.ModeSymlink: - if err = copySymLink(srcfp, dstfp); err != nil { - log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err) - } - case os.ModeDir: - if err = cpDir(srcfp, dstfp); err != nil { - log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err) - } - default: - if err = cpFile(srcfp, dstfp); err != nil { - log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err) - } - } - } - return nil -} - -func migrateToNetbird(oldPath, newPath string) bool { - _, errOld := os.Stat(oldPath) - _, errNew := os.Stat(newPath) - - if errors.Is(errOld, fs.ErrNotExist) || errNew == nil { - return false - } - - return true + return config, nil } func init() { runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.") runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") + runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") + runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") } diff --git a/signal/metrics/app.go b/signal/metrics/app.go index fb882a5d4..f8be88be7 100644 --- a/signal/metrics/app.go +++ b/signal/metrics/app.go @@ -15,6 +15,7 @@ type AppMetrics struct { Deregistrations metric.Int64Counter RegistrationFailures metric.Int64Counter RegistrationDelay metric.Float64Histogram + GetRegistrationDelay metric.Float64Histogram MessagesForwarded metric.Int64Counter MessageForwardFailures metric.Int64Counter @@ -54,6 +55,12 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) { return nil, err } + getRegistrationDelay, err := meter.Float64Histogram("get_registration_delay_milliseconds", + metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + messagesForwarded, err := meter.Int64Counter("messages_forwarded_total") if err != nil { return nil, err @@ -80,6 +87,7 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) { Deregistrations: deregistrations, RegistrationFailures: registrationFailures, RegistrationDelay: registrationDelay, + GetRegistrationDelay: getRegistrationDelay, MessagesForwarded: messagesForwarded, MessageForwardFailures: messageForwardFailures, diff --git a/signal/metrics/metrics.go b/signal/metrics/metrics.go index 30db1600a..f411501cb 100644 --- a/signal/metrics/metrics.go +++ b/signal/metrics/metrics.go @@ -26,10 +26,10 @@ type Metrics struct { } // NewServer initializes and returns a new Metrics instance -func NewServer(port int, endpoint string) *Metrics { +func NewServer(port int, endpoint string) (*Metrics, error) { exporter, err := prometheus.New() if err != nil { - return nil + return nil, err } provider := metric.NewMeterProvider(metric.WithReader(exporter)) @@ -57,7 +57,7 @@ func NewServer(port int, endpoint string) *Metrics { provider: provider, Endpoint: endpoint, Server: server, - } + }, nil } // Shutdown stops the metrics server diff --git a/signal/server/signal.go b/signal/server/signal.go index fc9c19efd..219bdcc41 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -23,11 +23,17 @@ const ( labelTypeError = "error" labelTypeNotConnected = "not_connected" labelTypeNotRegistered = "not_registered" + labelTypeStream = "stream" + labelTypeMessage = "message" labelError = "error" labelErrorMissingId = "missing_id" labelErrorMissingMeta = "missing_meta" labelErrorFailedHeader = "failed_header" + + labelRegistrationStatus = "status" + labelRegistrationFound = "found" + labelRegistrationNotFound = "not_found" ) // Server an instance of a Signal server @@ -61,7 +67,11 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. return nil, fmt.Errorf("peer %s is not registered", msg.Key) } + getRegistrationStart := time.Now() + if dstPeer, found := s.registry.Get(msg.RemoteKey); found { + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationFound))) + start := time.Now() //forward the message to the target peer if err := dstPeer.Stream.Send(msg); err != nil { log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) @@ -69,9 +79,11 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) } else { + s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) s.metrics.MessagesForwarded.Add(context.Background(), 1) } } else { + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) //todo respond to the sender? @@ -118,28 +130,30 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) } else if err != nil { return err } - start := time.Now() log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey) + getRegistrationStart := time.Now() + // lookup the target peer where the message is going to if dstPeer, found := s.registry.Get(msg.RemoteKey); found { + s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) + start := time.Now() //forward the message to the target peer if err := dstPeer.Stream.Send(msg); err != nil { log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err) //todo respond to the sender? - - // in milliseconds - s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6) - s.metrics.MessagesForwarded.Add(stream.Context(), 1) - } else { s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + } else { + // in milliseconds + s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) + s.metrics.MessagesForwarded.Add(stream.Context(), 1) } } else { + s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) + s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey) //todo respond to the sender? - - s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) } } <-stream.Context().Done() diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 3fba0c84e..57ab8fd55 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -2,12 +2,18 @@ package grpc import ( "context" + "crypto/tls" "net" "os/user" "runtime" + "time" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -35,3 +41,40 @@ func WithCustomDialer() grpc.DialOption { return conn, nil }) } + +// grpcDialBackoff is the backoff mechanism for the grpc calls +func Backoff(ctx context.Context) backoff.BackOff { + b := backoff.NewExponentialBackOff() + b.MaxElapsedTime = 10 * time.Second + b.Clock = backoff.SystemClock + return backoff.WithContext(b, ctx) +} + +func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { + transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) + + if tlsEnabled { + transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + } + + connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + conn, err := grpc.DialContext( + connCtx, + addr, + transportOption, + WithCustomDialer(), + grpc.WithBlock(), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 30 * time.Second, + Timeout: 10 * time.Second, + }), + ) + if err != nil { + log.Printf("DialContext error: %v", err) + return nil, err + } + + return conn, nil +} diff --git a/util/log.go b/util/log.go index 90ccea48f..74b99311e 100644 --- a/util/log.go +++ b/util/log.go @@ -4,6 +4,7 @@ import ( "io" "os" "path/filepath" + "slices" log "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" @@ -18,8 +19,9 @@ func InitLog(logLevel string, logPath string) error { log.Errorf("Failed parsing log-level %s: %s", logLevel, err) return err } + customOutputs := []string{"console", "syslog"}; - if logPath != "" && logPath != "console" { + if logPath != "" && !slices.Contains(customOutputs, logPath) { lumberjackLogger := &lumberjack.Logger{ // Log file absolute path, os agnostic Filename: filepath.ToSlash(logPath), @@ -29,6 +31,8 @@ func InitLog(logLevel string, logPath string) error { Compress: true, } log.SetOutput(io.Writer(lumberjackLogger)) + } else if logPath == "syslog" { + AddSyslogHook() } if os.Getenv("NB_LOG_FORMAT") == "json" { diff --git a/util/syslog_nonwindows.go b/util/syslog_nonwindows.go new file mode 100644 index 000000000..6ffbcb8be --- /dev/null +++ b/util/syslog_nonwindows.go @@ -0,0 +1,20 @@ +//go:build !windows +// +build !windows + +package util + +import ( + "log/syslog" + + log "github.com/sirupsen/logrus" + lSyslog "github.com/sirupsen/logrus/hooks/syslog" +) + +func AddSyslogHook() { + hook, err := lSyslog.NewSyslogHook("", "", syslog.LOG_INFO, "") + + if err != nil { + log.Errorf("Failed creating syslog hook: %s", err) + } + log.AddHook(hook) +} diff --git a/util/syslog_windows.go b/util/syslog_windows.go new file mode 100644 index 000000000..171c1a459 --- /dev/null +++ b/util/syslog_windows.go @@ -0,0 +1,6 @@ +package util + +func AddSyslogHook() { + // The syslog package is not available for Windows. This adapter is needed + // to handle windows build. +} diff --git a/versioninfo.json b/versioninfo.json new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/versioninfo.json @@ -0,0 +1 @@ +{}