diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 9e5e97a31..80809e667 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,15 +1,15 @@ -FROM golang:1.23-bullseye +FROM golang:1.25-bookworm RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ && apt-get -y install --no-install-recommends\ - gettext-base=0.21-4 \ - iptables=1.8.7-1 \ - libgl1-mesa-dev=20.3.5-1 \ - xorg-dev=1:7.7+22 \ - libayatana-appindicator3-dev=0.5.5-2+deb11u2 \ + gettext-base=0.21-12 \ + iptables=1.8.9-2 \ + libgl1-mesa-dev=22.3.6-1+deb12u1 \ + xorg-dev=1:7.7+23 \ + libayatana-appindicator3-dev=0.5.92-1 \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* \ - && go install -v golang.org/x/tools/gopls@v0.18.1 + && go install -v golang.org/x/tools/gopls@latest WORKDIR /app diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index 0d19e8a19..df64e86bb 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -25,7 +25,7 @@ jobs: release: "14.2" prepare: | pkg install -y curl pkgconf xorg - GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz" + GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz" GO_URL="https://go.dev/dl/$GO_TARBALL" curl -vLO "$GO_URL" tar -C /usr/local -vxzf "$GO_TARBALL" diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index c09bfab39..195a37a1f 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -200,7 +200,7 @@ jobs: -e GOCACHE=${CONTAINER_GOCACHE} \ -e GOMODCACHE=${CONTAINER_GOMODCACHE} \ -e CONTAINER=${CONTAINER} \ - golang:1.24-alpine \ + golang:1.25-alpine \ sh -c ' \ apk update; apk add --no-cache \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ @@ -259,7 +259,7 @@ jobs: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ go test ${{ matrix.raceFlag }} \ -exec 'sudo' \ - -timeout 10m ./relay/... ./shared/relay/... + -timeout 10m -p 1 ./relay/... ./shared/relay/... test_signal: name: "Signal / Unit" diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index c524f6f6b..9ce779dbb 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -52,7 +52,10 @@ jobs: if: matrix.os == 'ubuntu-latest' run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: golangci-lint - uses: golangci/golangci-lint-action@v4 + uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0 with: version: latest - args: --timeout=12m --out-format colored-line-number + skip-cache: true + skip-save-cache: true + cache-invalidation-interval: 0 + args: --timeout=12m diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2fa847dce..84f6f64ed 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -63,7 +63,7 @@ jobs: pkg install -y git curl portlint go # Install Go for building - GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz" + GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz" GO_URL="https://go.dev/dl/$GO_TARBALL" curl -LO "$GO_URL" tar -C /usr/local -xzf "$GO_TARBALL" diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml index 4100e16dd..47e45165b 100644 --- a/.github/workflows/wasm-build-validation.yml +++ b/.github/workflows/wasm-build-validation.yml @@ -14,6 +14,9 @@ jobs: js_lint: name: "JS / Lint" runs-on: ubuntu-latest + env: + GOOS: js + GOARCH: wasm steps: - name: Checkout repository uses: actions/checkout@v4 @@ -24,16 +27,14 @@ jobs: - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: Install golangci-lint - uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc + uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0 with: version: latest install-mode: binary skip-cache: true - skip-pkg-cache: true - skip-build-cache: true - - name: Run golangci-lint for WASM - run: | - GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/... + skip-save-cache: true + cache-invalidation-interval: 0 + working-directory: ./client continue-on-error: true js_build: diff --git a/.golangci.yaml b/.golangci.yaml index 461677c2e..d81ad1377 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,139 +1,124 @@ -run: - # Timeout for analysis, e.g. 30s, 5m. - # Default: 1m - timeout: 6m - -# This file contains only configs which differ from defaults. -# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -linters-settings: - errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: false - - gosec: - includes: - - G101 # Look for hard coded credentials - #- G102 # Bind to all interfaces - - G103 # Audit the use of unsafe block - - G104 # Audit errors not checked - - G106 # Audit the use of ssh.InsecureIgnoreHostKey - #- G107 # Url provided to HTTP request as taint input - - G108 # Profiling endpoint automatically exposed on /debug/pprof - - G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32 - - G110 # Potential DoS vulnerability via decompression bomb - - G111 # Potential directory traversal - #- G112 # Potential slowloris attack - - G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772) - #- G114 # Use of net/http serve function that has no support for setting timeouts - - G201 # SQL query construction using format string - - G202 # SQL query construction using string concatenation - - G203 # Use of unescaped data in HTML templates - #- G204 # Audit use of command execution - - G301 # Poor file permissions used when creating a directory - - G302 # Poor file permissions used with chmod - - G303 # Creating tempfile using a predictable path - - G304 # File path provided as taint input - - G305 # File traversal when extracting zip/tar archive - - G306 # Poor file permissions used when writing to a new file - - G307 # Poor file permissions used when creating a file with os.Create - #- G401 # Detect the usage of DES, RC4, MD5 or SHA1 - #- G402 # Look for bad TLS connection settings - - G403 # Ensure minimum RSA key length of 2048 bits - #- G404 # Insecure random number source (rand) - #- G501 # Import blocklist: crypto/md5 - - G502 # Import blocklist: crypto/des - - G503 # Import blocklist: crypto/rc4 - - G504 # Import blocklist: net/http/cgi - #- G505 # Import blocklist: crypto/sha1 - - G601 # Implicit memory aliasing of items from a range statement - - G602 # Slice access out of bounds - - gocritic: - disabled-checks: - - commentFormatting - - captLocal - - deprecatedComment - - govet: - # Enable all analyzers. - # Default: false - enable-all: false - enable: - - nilness - - revive: - rules: - - name: exported - severity: warning - disabled: false - arguments: - - "checkPrivateReceivers" - - "sayRepetitiveInsteadOfStutters" - tenv: - # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. - # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - # Default: false - all: true - +version: "2" linters: - disable-all: true + default: none enable: - ## enabled by default - - errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases - - gosimple # specializes in simplifying a code - - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - - ineffassign # detects when assignments to existing variables are not used - - staticcheck # is a go vet on steroids, applying a ton of static analysis checks - - tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17. - - typecheck # like the front-end of a Go compiler, parses and type-checks Go code - - unused # checks for unused constants, variables, functions and types - ## disable by default but the have interesting results so lets add them - - bodyclose # checks whether HTTP response body is closed successfully - - dupword # dupword checks for duplicate words in the source code - - durationcheck # durationcheck checks for two durations multiplied together - - forbidigo # forbidigo forbids identifiers - - gocritic # provides diagnostics that check for bugs, performance and style issues - - gosec # inspects source code for security problems - - mirror # mirror reports wrong mirror patterns of bytes/strings usage - - misspell # misspess finds commonly misspelled English words in comments - - nilerr # finds the code that returns nil even if it checks that the error is not nil - - nilnil # checks that there is no simultaneous return of nil error and an invalid value - - predeclared # predeclared finds code that shadows one of Go's predeclared identifiers - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed - # - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers. - - wastedassign # wastedassign finds wasted assignment statements + - bodyclose + - dupword + - durationcheck + - errcheck + - forbidigo + - gocritic + - gosec + - govet + - ineffassign + - mirror + - misspell + - nilerr + - nilnil + - predeclared + - revive + - sqlclosecheck + - staticcheck + - unused + - wastedassign + settings: + errcheck: + check-type-assertions: false + gocritic: + disabled-checks: + - commentFormatting + - captLocal + - deprecatedComment + gosec: + includes: + - G101 + - G103 + - G104 + - G106 + - G108 + - G109 + - G110 + - G111 + - G201 + - G202 + - G203 + - G301 + - G302 + - G303 + - G304 + - G305 + - G306 + - G307 + - G403 + - G502 + - G503 + - G504 + - G601 + - G602 + govet: + enable: + - nilness + enable-all: false + revive: + rules: + - name: exported + arguments: + - checkPrivateReceivers + - sayRepetitiveInsteadOfStutters + severity: warning + disabled: false + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - linters: + - forbidigo + path: management/cmd/root\.go + - linters: + - forbidigo + path: signal/cmd/root\.go + - linters: + - unused + path: sharedsock/filter\.go + - linters: + - unused + path: client/firewall/iptables/rule\.go + - linters: + - gosec + - mirror + path: test\.go + - linters: + - nilnil + path: mock\.go + - linters: + - staticcheck + text: grpc.DialContext is deprecated + - linters: + - staticcheck + text: grpc.WithBlock is deprecated + - linters: + - staticcheck + text: "QF1001" + - linters: + - staticcheck + text: "QF1008" + - linters: + - staticcheck + text: "QF1012" + paths: + - third_party$ + - builtin$ + - examples$ issues: - # Maximum count of issues with the same text. - # Set to 0 to disable. - # Default: 3 max-same-issues: 5 - - exclude-rules: - # allow fmt - - path: management/cmd/root\.go - linters: forbidigo - - path: signal/cmd/root\.go - linters: forbidigo - - path: sharedsock/filter\.go - linters: - - unused - - path: client/firewall/iptables/rule\.go - linters: - - unused - - path: test\.go - linters: - - mirror - - gosec - - path: mock\.go - linters: - - nilnil - # Exclude specific deprecation warnings for grpc methods - - linters: - - staticcheck - text: "grpc.DialContext is deprecated" - - linters: - - staticcheck - text: "grpc.WithBlock is deprecated" +formatters: + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 430012a17..7ca56857b 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -136,6 +136,7 @@ func setLogLevel(cmd *cobra.Command, args []string) error { client := proto.NewDaemonServiceClient(conn) level := server.ParseLogLevel(args[0]) if level == proto.LogLevel_UNKNOWN { + //nolint return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0]) } diff --git a/client/cmd/login.go b/client/cmd/login.go index a34bb7c70..57c010571 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -81,6 +81,7 @@ var loginCmd = &cobra.Command{ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { + //nolint return fmt.Errorf("failed to connect to daemon error: %v\n"+ "If the daemon is not running please run: "+ "\nnetbird service install \nnetbird service start\n", err) @@ -206,6 +207,7 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage func switchProfile(ctx context.Context, profileName string, username string) error { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { + //nolint return fmt.Errorf("failed to connect to daemon error: %v\n"+ "If the daemon is not running please run: "+ "\nnetbird service install \nnetbird service start\n", err) diff --git a/client/cmd/pprof.go b/client/cmd/pprof.go index 37efd35f0..c041c6ea9 100644 --- a/client/cmd/pprof.go +++ b/client/cmd/pprof.go @@ -1,5 +1,4 @@ //go:build pprof -// +build pprof package cmd diff --git a/client/cmd/root.go b/client/cmd/root.go index 30120c196..f4f4f6052 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -390,6 +390,7 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) { conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) if err != nil { + //nolint return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ "If the daemon is not running please run: "+ "\nnetbird service install \nnetbird service start\n", err) diff --git a/client/cmd/status.go b/client/cmd/status.go index 06460a6a7..99d47cd1a 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -124,6 +124,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { + //nolint return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ "If the daemon is not running please run: "+ "\nnetbird service install \nnetbird service start\n", err) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 888a9a3f7..2650d6225 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -89,9 +89,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp t.Cleanup(cleanUp) eventStore := &activity.InMemoryEventStore{} - if err != nil { - return nil, nil - } ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) diff --git a/client/cmd/up.go b/client/cmd/up.go index 9efc2e60d..057d35268 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -216,6 +216,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { + //nolint return fmt.Errorf("failed to connect to daemon error: %v\n"+ "If the daemon is not running please run: "+ "\nnetbird service install \nnetbird service start\n", err) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 5ccaf17ba..d83798f09 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -386,11 +386,8 @@ func (m *aclManager) updateState() { // filterRuleSpecs returns the specs of a filtering rule func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { - matchByIP := true // don't use IP matching if IP is 0.0.0.0 - if ip.IsUnspecified() { - matchByIP = false - } + matchByIP := !ip.IsUnspecified() if matchByIP { if ipsetName != "" { diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 6b5401e2b..ee47a27c0 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) { t.Logf(" [%d] %s", i, rule) } - var denyRuleIndex, acceptRuleIndex int = -1, -1 + var denyRuleIndex, acceptRuleIndex = -1, -1 for i, rule := range rules { if strings.Contains(rule, "DROP") { t.Logf("Found DROP rule at index %d: %s", i, rule) diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 6b29c5606..75b1e2b6c 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) { t.Logf("Found %d rules in nftables chain", len(rules)) // Find the accept and deny rules and verify deny comes before accept - var acceptRuleIndex, denyRuleIndex int = -1, -1 + var acceptRuleIndex, denyRuleIndex = -1, -1 for i, rule := range rules { hasAcceptHTTPSet := false hasDenyHTTPSet := false @@ -208,11 +208,13 @@ func TestNftablesManagerRuleOrder(t *testing.T) { for _, e := range rule.Exprs { // Check for set lookup if lookup, ok := e.(*expr.Lookup); ok { - if lookup.SetName == "accept-http" { + switch lookup.SetName { + case "accept-http": hasAcceptHTTPSet = true - } else if lookup.SetName == "deny-http" { + case "deny-http": hasDenyHTTPSet = true } + } // Check for port 80 if cmp, ok := e.(*expr.Cmp); ok { @@ -222,9 +224,10 @@ func TestNftablesManagerRuleOrder(t *testing.T) { } // Check for verdict if verdict, ok := e.(*expr.Verdict); ok { - if verdict.Kind == expr.VerdictAccept { + switch verdict.Kind { + case expr.VerdictAccept: action = "ACCEPT" - } else if verdict.Kind == expr.VerdictDrop { + case expr.VerdictDrop: action = "DROP" } } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 4e22bde3f..8caa1a0ad 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -29,7 +29,7 @@ import ( ) const ( - layerTypeAll = 0 + layerTypeAll = 255 // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation ipTCPHeaderMinSize = 40 @@ -262,10 +262,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe } func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { - wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) - if err != nil { - return nil, fmt.Errorf("parse wireguard network: %w", err) - } + wgPrefix := iface.Address().Network log.Debugf("blocking invalid routed traffic for %s", wgPrefix) rule, err := m.addRouteFiltering( @@ -439,19 +436,7 @@ func (m *Manager) AddPeerFiltering( r.sPort = sPort r.dPort = dPort - switch proto { - case firewall.ProtocolTCP: - r.protoLayer = layers.LayerTypeTCP - case firewall.ProtocolUDP: - r.protoLayer = layers.LayerTypeUDP - case firewall.ProtocolICMP: - r.protoLayer = layers.LayerTypeICMPv4 - if r.ipLayer == layers.LayerTypeIPv6 { - r.protoLayer = layers.LayerTypeICMPv6 - } - case firewall.ProtocolALL: - r.protoLayer = layerTypeAll - } + r.protoLayer = protoToLayer(proto, r.ipLayer) m.mutex.Lock() var targetMap map[netip.Addr]RuleSet @@ -496,16 +481,17 @@ func (m *Manager) addRouteFiltering( } ruleID := uuid.New().String() + rule := RouteRule{ // TODO: consolidate these IDs - id: ruleID, - mgmtId: id, - sources: sources, - dstSet: destination.Set, - proto: proto, - srcPort: sPort, - dstPort: dPort, - action: action, + id: ruleID, + mgmtId: id, + sources: sources, + dstSet: destination.Set, + protoLayer: protoToLayer(proto, layers.LayerTypeIPv4), + srcPort: sPort, + dstPort: dPort, + action: action, } if destination.IsPrefix() { rule.destinations = []netip.Prefix{destination.Prefix} @@ -795,7 +781,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade pseudoSum += uint32(d.ip4.Protocol) pseudoSum += uint32(tcpLength) - var sum uint32 = pseudoSum + var sum = pseudoSum for i := 0; i < tcpLength-1; i += 2 { sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) } @@ -945,7 +931,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData) if blocked { - _, pnum := getProtocolFromPacket(d) + pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", @@ -1010,20 +996,22 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe return false } - proto, pnum := getProtocolFromPacket(d) + protoLayer := d.decoded[1] srcPort, dstPort := getPortsFromPacket(d) - ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort) if !pass { + proto := getProtocolFromPacket(d) + m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleID, pnum, srcIP, srcPort, dstIP, dstPort) + ruleID, proto, srcIP, srcPort, dstIP, dstPort) m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), Type: nftypes.TypeDrop, RuleID: ruleID, Direction: nftypes.Ingress, - Protocol: pnum, + Protocol: proto, SourceIP: srcIP, DestIP: dstIP, SourcePort: srcPort, @@ -1052,16 +1040,33 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe return true } -func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) { +func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType { + switch proto { + case firewall.ProtocolTCP: + return layers.LayerTypeTCP + case firewall.ProtocolUDP: + return layers.LayerTypeUDP + case firewall.ProtocolICMP: + if ipLayer == layers.LayerTypeIPv6 { + return layers.LayerTypeICMPv6 + } + return layers.LayerTypeICMPv4 + case firewall.ProtocolALL: + return layerTypeAll + } + return 0 +} + +func getProtocolFromPacket(d *decoder) nftypes.Protocol { switch d.decoded[1] { case layers.LayerTypeTCP: - return firewall.ProtocolTCP, nftypes.TCP + return nftypes.TCP case layers.LayerTypeUDP: - return firewall.ProtocolUDP, nftypes.UDP + return nftypes.UDP case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: - return firewall.ProtocolICMP, nftypes.ICMP + return nftypes.ICMP default: - return firewall.ProtocolALL, nftypes.ProtocolUnknown + return nftypes.ProtocolUnknown } } @@ -1233,19 +1238,30 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d } // routeACLsPass returns true if the packet is allowed by the route ACLs -func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) { +func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) { m.mutex.RLock() defer m.mutex.RUnlock() for _, rule := range m.routeRules { - if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches { + if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches { return rule.mgmtId, rule.action == firewall.ActionAccept } } return nil, false } -func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { +func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool { + // TODO: handle ipv6 vs ipv4 icmp rules + if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer { + return false + } + + if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP { + if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) { + return false + } + } + destMatched := false for _, dst := range rule.destinations { if dst.Contains(dstAddr) { @@ -1264,21 +1280,8 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot break } } - if !sourceMatched { - return false - } - if rule.proto != firewall.ProtocolALL && rule.proto != proto { - return false - } - - if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP { - if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) { - return false - } - } - - return true + return sourceMatched } // AddUDPPacketHook calls hook when UDP packet from given direction matched diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 5a2d0410f..10ff62ed3 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) { for _, tc := range cases { srcIP := netip.MustParseAddr(tc.srcIP) dstIP := netip.MustParseAddr(tc.dstIP) - manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) + manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort) } } } diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index eb5aa3343..a8efbac1c 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) { // testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed // to the forwarder - _, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) + _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort) require.Equal(t, tc.shouldPass, isAllowed) }) } @@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) { srcIP := netip.MustParseAddr(p.srcIP) dstIP := netip.MustParseAddr(p.dstIP) - _, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) + _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) } }) @@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) { dstIP := netip.MustParseAddr("192.168.1.100") // Check that traffic is dropped (empty set shouldn't match anything) - _, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.False(t, isAllowed, "Empty set should not allow any traffic") err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}) require.NoError(t, err) // Now the packet should be allowed - _, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + _, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") } diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 120a9f418..c6a4ebeb8 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) { dstIP2 := netip.MustParseAddr("192.168.1.100") dstIP3 := netip.MustParseAddr("172.16.0.100") - _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) - _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) - _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80) + _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) + _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) + _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed") @@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) { require.NoError(t, err) // Check that all original prefixes are still included - _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) - _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) + _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) + _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update") @@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) { dstIP4 := netip.MustParseAddr("172.16.1.100") dstIP5 := netip.MustParseAddr("10.1.0.50") - _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80) - _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80) + _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) + _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed") require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed") @@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) { srcIP := netip.MustParseAddr("100.10.0.1") for _, tc := range testCases { - _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80) + _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.Equal(t, tc.expected, isAllowed, tc.desc) } } diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index f91291ea8..692a24140 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -2,6 +2,7 @@ package forwarder import ( "fmt" + "sync/atomic" wgdevice "golang.zx2c4.com/wireguard/device" "gvisor.dev/gvisor/pkg/tcpip" @@ -16,7 +17,7 @@ type endpoint struct { logger *nblog.Logger dispatcher stack.NetworkDispatcher device *wgdevice.Device - mtu uint32 + mtu atomic.Uint32 } func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { @@ -28,7 +29,7 @@ func (e *endpoint) IsAttached() bool { } func (e *endpoint) MTU() uint32 { - return e.mtu + return e.mtu.Load() } func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { @@ -82,6 +83,22 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool { return true } +func (e *endpoint) Close() { + // Endpoint cleanup - nothing to do as device is managed externally +} + +func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) { + // Link address is not used for this endpoint type +} + +func (e *endpoint) SetMTU(mtu uint32) { + e.mtu.Store(mtu) +} + +func (e *endpoint) SetOnCloseAction(func()) { + // No action needed on close +} + type epID stack.TransportEndpointID func (i epID) String() string { diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 00cb3f1df..d17c3cd5c 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -7,6 +7,7 @@ import ( "net/netip" "runtime" "sync" + "time" log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/buffer" @@ -35,14 +36,16 @@ type Forwarder struct { logger *nblog.Logger flowLogger nftypes.FlowLogger // ruleIdMap is used to store the rule ID for a given connection - ruleIdMap sync.Map - stack *stack.Stack - endpoint *endpoint - udpForwarder *udpForwarder - ctx context.Context - cancel context.CancelFunc - ip tcpip.Address - netstack bool + ruleIdMap sync.Map + stack *stack.Stack + endpoint *endpoint + udpForwarder *udpForwarder + ctx context.Context + cancel context.CancelFunc + ip tcpip.Address + netstack bool + hasRawICMPAccess bool + pingSemaphore chan struct{} } func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { @@ -60,8 +63,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow endpoint := &endpoint{ logger: logger, device: iface.GetWGDevice(), - mtu: uint32(mtu), } + endpoint.mtu.Store(uint32(mtu)) if err := s.CreateNIC(nicID, endpoint); err != nil { return nil, fmt.Errorf("create NIC: %v", err) @@ -103,15 +106,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow ctx, cancel := context.WithCancel(context.Background()) f := &Forwarder{ - logger: logger, - flowLogger: flowLogger, - stack: s, - endpoint: endpoint, - udpForwarder: newUDPForwarder(mtu, logger, flowLogger), - ctx: ctx, - cancel: cancel, - netstack: netstack, - ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), + logger: logger, + flowLogger: flowLogger, + stack: s, + endpoint: endpoint, + udpForwarder: newUDPForwarder(mtu, logger, flowLogger), + ctx: ctx, + cancel: cancel, + netstack: netstack, + ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), + pingSemaphore: make(chan struct{}, 3), } receiveWindow := defaultReceiveWindow @@ -129,6 +133,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) + f.checkICMPCapability() + log.Debugf("forwarder: Initialization complete with NIC %d", nicID) return f, nil } @@ -198,3 +204,24 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe DstPort: dstPort, } } + +// checkICMPCapability tests whether we have raw ICMP socket access at startup. +func (f *Forwarder) checkICMPCapability() { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + lc := net.ListenConfig{} + conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + if err != nil { + f.hasRawICMPAccess = false + f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback") + return + } + + if err := conn.Close(); err != nil { + f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err) + } + + f.hasRawICMPAccess = true + f.logger.Debug("forwarder: Raw ICMP socket access available") +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 939c04789..cb3db325d 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -2,8 +2,11 @@ package forwarder import ( "context" + "fmt" "net" "net/netip" + "os/exec" + "runtime" "time" "github.com/google/uuid" @@ -14,30 +17,95 @@ import ( ) // handleICMP handles ICMP packets from the network stack -func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { +func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) - icmpType := uint8(icmpHdr.Type()) - icmpCode := uint8(icmpHdr.Code()) - - if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply { - // dont process our own replies - return true - } flowID := uuid.New() - f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0) + f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0) - ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) + // For Echo Requests, send and wait for response + if icmpHdr.Type() == header.ICMPv4Echo { + return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code())) + } + + // For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting + if !f.hasRawICMPAccess { + f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id)) + return false + } + + icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond) + if err != nil { + f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err) + return true + } + if err := conn.Close(); err != nil { + f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err) + } + + return true +} + +// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting. +func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool { + select { + case f.pingSemaphore <- struct{}{}: + icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice() + rxBytes := pkt.Size() + + go func() { + defer func() { <-f.pingSemaphore }() + + if f.hasRawICMPAccess { + f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes) + } else { + f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes) + } + }() + default: + f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v", + epID(id), icmpType, icmpCode) + } + return true +} + +// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection. +// The caller is responsible for closing the returned connection. +func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) { + ctx, cancel := context.WithTimeout(f.ctx, timeout) defer cancel() lc := net.ListenConfig{} - // TODO: support non-root conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") if err != nil { - f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err) + return nil, fmt.Errorf("create ICMP socket: %w", err) + } - // This will make netstack reply on behalf of the original destination, that's ok for now - return false + dstIP := f.determineDialAddr(id.LocalAddress) + dst := &net.IPAddr{IP: dstIP} + + if _, err = conn.WriteTo(payload, dst); err != nil { + if closeErr := conn.Close(); closeErr != nil { + f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr) + } + return nil, fmt.Errorf("write ICMP packet: %w", err) + } + + f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", + epID(id), icmpType, icmpCode) + + return conn, nil +} + +// handleICMPViaSocket handles ICMP echo requests using raw sockets. +func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { + sendTime := time.Now() + + conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second) + if err != nil { + f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err) + return } defer func() { if err := conn.Close(); err != nil { @@ -45,38 +113,22 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf } }() - dstIP := f.determineDialAddr(id.LocalAddress) - dst := &net.IPAddr{IP: dstIP} + txBytes := f.handleEchoResponse(conn, id) + rtt := time.Since(sendTime).Round(10 * time.Microsecond) - fullPacket := stack.PayloadSince(pkt.TransportHeader()) - payload := fullPacket.AsSlice() + f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)", + epID(id), icmpType, icmpCode, rtt) - if _, err = conn.WriteTo(payload, dst); err != nil { - f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err) - return true - } - - f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", - epID(id), icmpHdr.Type(), icmpHdr.Code()) - - // For Echo Requests, send and handle response - if header.ICMPv4Type(icmpType) == header.ICMPv4Echo { - rxBytes := pkt.Size() - txBytes := f.handleEchoResponse(icmpHdr, conn, id) - f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) - } - - // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing - return true + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } -func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int { +func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err) return 0 } - response := make([]byte, f.endpoint.mtu) + response := make([]byte, f.endpoint.mtu.Load()) n, _, err := conn.ReadFrom(response) if err != nil { if !isTimeout(err) { @@ -85,31 +137,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon return 0 } - ipHdr := make([]byte, header.IPv4MinimumSize) - ip := header.IPv4(ipHdr) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(header.IPv4MinimumSize + n), - TTL: 64, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: id.LocalAddress, - DstAddr: id.RemoteAddress, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - fullPacket := make([]byte, 0, len(ipHdr)+n) - fullPacket = append(fullPacket, ipHdr...) - fullPacket = append(fullPacket, response[:n]...) - - if err := f.InjectIncomingPacket(fullPacket); err != nil { - f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err) - - return 0 - } - - f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v", - epID(id), icmpHdr.Type(), icmpHdr.Code()) - - return len(fullPacket) + return f.injectICMPReply(id, response[:n]) } // sendICMPEvent stores flow events for ICMP packets @@ -152,3 +180,95 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T f.flowLogger.StoreEvent(fields) } + +// handleICMPViaPing handles ICMP echo requests by executing the system ping binary. +// This is used as a fallback when raw socket access is not available. +func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { + ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) + defer cancel() + + dstIP := f.determineDialAddr(id.LocalAddress) + cmd := buildPingCommand(ctx, dstIP, 5*time.Second) + + pingStart := time.Now() + if err := cmd.Run(); err != nil { + f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id), + icmpType, icmpCode, err) + return + } + rtt := time.Since(pingStart).Round(10 * time.Microsecond) + + f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v", + epID(id), icmpType, icmpCode) + + txBytes := f.synthesizeEchoReply(id, icmpData) + + f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)", + epID(id), icmpType, icmpCode, rtt) + + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) +} + +// buildPingCommand creates a platform-specific ping command. +func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd { + timeoutSec := int(timeout.Seconds()) + if timeoutSec < 1 { + timeoutSec = 1 + } + + switch runtime.GOOS { + case "linux", "android": + return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) + case "darwin", "ios": + return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) + case "freebsd": + return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String()) + case "openbsd", "netbsd": + return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String()) + case "windows": + return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String()) + default: + return exec.CommandContext(ctx, "ping", "-c", "1", target.String()) + } +} + +// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack. +// Returns the size of the injected packet. +func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int { + replyICMP := make([]byte, len(icmpData)) + copy(replyICMP, icmpData) + + replyICMPHdr := header.ICMPv4(replyICMP) + replyICMPHdr.SetType(header.ICMPv4EchoReply) + replyICMPHdr.SetChecksum(0) + replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0)) + + return f.injectICMPReply(id, replyICMP) +} + +// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack. +// Returns the total size of the injected packet, or 0 if injection failed. +func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int { + ipHdr := make([]byte, header.IPv4MinimumSize) + ip := header.IPv4(ipHdr) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: id.LocalAddress, + DstAddr: id.RemoteAddress, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload)) + fullPacket = append(fullPacket, ipHdr...) + fullPacket = append(fullPacket, icmpPayload...) + + // Bypass netstack and send directly to peer to avoid looping through our ICMP handler + if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil { + f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err) + return 0 + } + + return len(fullPacket) +} diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 55743d975..f175e275b 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "net" "net/netip" "sync" @@ -131,10 +132,10 @@ func (f *udpForwarder) cleanup() { } // handleUDP is called by the UDP forwarder for new packets -func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { +func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { if f.ctx.Err() != nil { f.logger.Trace("forwarder: context done, dropping UDP packet") - return + return false } id := r.ID() @@ -144,7 +145,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { f.udpForwarder.RUnlock() if exists { f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) - return + return true } flowID := uuid.New() @@ -162,7 +163,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err != nil { f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err) // TODO: Send ICMP error message - return + return false } // Create wait queue for blocking syscalls @@ -173,10 +174,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err := outConn.Close(); err != nil { f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err) } - return + return false } - inConn := gonet.NewUDPConn(f.stack, &wq, ep) + inConn := gonet.NewUDPConn(&wq, ep) connCtx, connCancel := context.WithCancel(f.ctx) pConn := &udpPacketConn{ @@ -199,7 +200,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err := outConn.Close(); err != nil { f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err) } - return + return true } f.udpForwarder.conns[id] = pConn f.udpForwarder.Unlock() @@ -208,6 +209,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) go f.proxyUDP(connCtx, pConn, id, ep) + return true } func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { @@ -348,7 +350,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu } func isClosedError(err error) bool { - return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) + return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) } func isTimeout(err error) bool { diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go index 7f6b52c71..ffc807f46 100644 --- a/client/firewall/uspfilter/localip.go +++ b/client/firewall/uspfilter/localip.go @@ -130,6 +130,7 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { // 127.0.0.0/8 newIPv4Bitmap[127] = &ipv4LowBitmap{} for i := 0; i < 8192; i++ { + // #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF } diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 45ac912cd..6653947fa 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { // nolint:gosimple - _, _ = mapManager.localIPs[ip.String()] + _ = mapManager.localIPs[ip.String()] } }) @@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { // nolint:gosimple - _, _ = mapManager.localIPs[ip.String()] + _ = mapManager.localIPs[ip.String()] } }) } diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index 139f702f2..66308defc 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -168,6 +168,15 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) { } } +func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) { + if l.level.Load() >= uint32(LevelWarn) { + select { + case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: + default: + } + } +} + func (l *Logger) Debug1(format string, arg1 any) { if l.level.Load() >= uint32(LevelDebug) { select { diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 400d61020..50743d006 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -234,9 +234,10 @@ func TestInboundPortDNATNegative(t *testing.T) { require.False(t, translated, "Packet should NOT be translated for %s", tc.name) d = parsePacket(t, packet) - if tc.protocol == layers.IPProtocolTCP { + switch tc.protocol { + case layers.IPProtocolTCP: require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") - } else if tc.protocol == layers.IPProtocolUDP { + case layers.IPProtocolUDP: require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") } }) diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index b765c72e9..dbe3a7858 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -34,7 +34,7 @@ type RouteRule struct { sources []netip.Prefix dstSet firewall.Set destinations []netip.Prefix - proto firewall.Protocol + protoLayer gopacket.LayerType srcPort *firewall.Port dstPort *firewall.Port action firewall.Action diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index c46a6581d..69c2519bf 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace { } func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace { - proto, _ := getProtocolFromPacket(d) + protoLayer := d.decoded[1] srcPort, dstPort := getPortsFromPacket(d) - id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort) strId := string(id) if id == nil { diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index dfb22ecde..0957d2dd5 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -27,8 +27,23 @@ type receiverCreator struct { iceBind *ICEBind } -func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc { - return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool) +func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc { + if ipv4PC, ok := pc.(*ipv4.PacketConn); ok { + return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool) + } + // IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the + // wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6. + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { + buf := bufs[0] + size, ep, err := conn.ReadFromUDPAddrPort(buf) + if err != nil { + return 0, err + } + sizes[0] = size + stdEp := &wgConn.StdNetEndpoint{AddrPort: ep} + eps[0] = stdEp + return 1, nil + } } // ICEBind is a bind implementation with two main features: diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index d841ac2fe..aa77cee45 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -1,6 +1,3 @@ -//go:build ios -// +build ios - package device import ( diff --git a/client/internal/debug/debug_linux.go b/client/internal/debug/debug_linux.go index 39d796fda..aedf88b79 100644 --- a/client/internal/debug/debug_linux.go +++ b/client/internal/debug/debug_linux.go @@ -507,15 +507,13 @@ func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string { if p.Base == expr.PayloadBaseNetworkHeader { switch p.Offset { case 12: - if p.Len == 4 { - return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) - } else if p.Len == 2 { + switch p.Len { + case 4, 2: return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) } case 16: - if p.Len == 4 { - return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) - } else if p.Len == 2 { + switch p.Len { + case 4, 2: return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) } } diff --git a/client/internal/iface.go b/client/internal/iface.go index bd0069c19..a82d87aab 100644 --- a/client/internal/iface.go +++ b/client/internal/iface.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package internal diff --git a/client/internal/routemanager/iface/iface.go b/client/internal/routemanager/iface/iface.go index 57dbec03d..b44d9fa65 100644 --- a/client/internal/routemanager/iface/iface.go +++ b/client/internal/routemanager/iface/iface.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package iface diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 26a548634..ec219c7fe 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -210,7 +210,8 @@ func (r *SysOps) refreshLocalSubnetsCache() { func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { nextHop := Nexthop{netip.Addr{}, intf} - if prefix == vars.Defaultv4 { + switch prefix { + case vars.Defaultv4: if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil { return err } @@ -233,7 +234,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er } return nil - } else if prefix == vars.Defaultv6 { + case vars.Defaultv6: if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { return fmt.Errorf("add unreachable route split 1: %w", err) } @@ -255,7 +256,8 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { nextHop := Nexthop{netip.Addr{}, intf} - if prefix == vars.Defaultv4 { + switch prefix { + case vars.Defaultv4: var result *multierror.Error if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil { result = multierror.Append(result, err) @@ -273,7 +275,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } return nberrors.FormatErrorOrNil(result) - } else if prefix == vars.Defaultv6 { + case vars.Defaultv6: var result *multierror.Error if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { result = multierror.Append(result, err) @@ -283,9 +285,9 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } return nberrors.FormatErrorOrNil(result) + default: + return r.removeFromRouteTable(prefix, nextHop) } - - return r.removeFromRouteTable(prefix, nextHop) } func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index e901386d9..935910fc9 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -76,7 +76,7 @@ type Client struct { loginComplete bool connectClient *internal.ConnectClient // preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked) - preloadedConfig *profilemanager.Config + preloadedConfig *profilemanager.Config } // NewClient instantiate a new Client diff --git a/client/server/panic_windows.go b/client/server/panic_windows.go index f441ec9ea..8592f12ad 100644 --- a/client/server/panic_windows.go +++ b/client/server/panic_windows.go @@ -1,5 +1,4 @@ //go:build windows -// +build windows package server diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go index d36d7cbbf..6eb88accc 100644 --- a/client/ssh/server/jwt_test.go +++ b/client/ssh/server/jwt_test.go @@ -602,12 +602,13 @@ func TestJWTAuthentication(t *testing.T) { require.NoError(t, err) var authMethods []cryptossh.AuthMethod - if tc.token == "valid" { + switch tc.token { + case "valid": token := generateValidJWT(t, privateKey, issuer, audience) authMethods = []cryptossh.AuthMethod{ cryptossh.Password(token), } - } else if tc.token == "invalid" { + case "invalid": invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid" authMethods = []cryptossh.AuthMethod{ cryptossh.Password(invalidToken), diff --git a/client/system/info_android.go b/client/system/info_android.go index 78895bfa8..794ff15ed 100644 --- a/client/system/info_android.go +++ b/client/system/info_android.go @@ -1,6 +1,3 @@ -//go:build android -// +build android - package system import ( diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index caa344737..4a31920ec 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -1,5 +1,4 @@ //go:build !ios -// +build !ios package system diff --git a/client/system/info_ios.go b/client/system/info_ios.go index 705c37920..322609db4 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -1,6 +1,3 @@ -//go:build ios -// +build ios - package system import ( diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 78934ea95..5d955ed25 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -510,7 +510,7 @@ func (s *serviceClient) saveSettings() { // Continue with default behavior if features can't be retrieved } else if features != nil && features.DisableUpdateSettings { log.Warn("Configuration updates are disabled by daemon") - dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings) + dialog.ShowError(fmt.Errorf("configuration updates are disabled by daemon"), s.wSettings) return } @@ -540,7 +540,7 @@ func (s *serviceClient) saveSettings() { func (s *serviceClient) validateSettings() error { if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { - return fmt.Errorf("Invalid Pre-shared Key Value") + return fmt.Errorf("invalid pre-shared key value") } } return nil @@ -549,10 +549,10 @@ func (s *serviceClient) validateSettings() error { func (s *serviceClient) parseNumericSettings() (int64, int64, error) { port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) if err != nil { - return 0, 0, errors.New("Invalid interface port") + return 0, 0, errors.New("invalid interface port") } if port < 1 || port > 65535 { - return 0, 0, errors.New("Invalid interface port: out of range 1-65535") + return 0, 0, errors.New("invalid interface port: out of range 1-65535") } var mtu int64 @@ -560,7 +560,7 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) { if mtuText != "" { mtu, err = strconv.ParseInt(mtuText, 10, 64) if err != nil { - return 0, 0, errors.New("Invalid MTU value") + return 0, 0, errors.New("invalid MTU value") } if mtu < iface.MinMTU || mtu > iface.MaxMTU { return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU) @@ -645,7 +645,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) ( if sshJWTCacheTTLText != "" { sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32) if err != nil { - return nil, errors.New("Invalid SSH JWT Cache TTL value") + return nil, errors.New("invalid SSH JWT Cache TTL value") } if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL { return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL) diff --git a/client/ui/signal_windows.go b/client/ui/signal_windows.go index ca98be526..58f46374f 100644 --- a/client/ui/signal_windows.go +++ b/client/ui/signal_windows.go @@ -164,7 +164,7 @@ func sendShowWindowSignal(pid int32) error { err = windows.SetEvent(eventHandle) if err != nil { - return fmt.Errorf("Error setting event: %w", err) + return fmt.Errorf("error setting event: %w", err) } return nil diff --git a/go.mod b/go.mod index 23cf0f37d..cf55b9260 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/netbirdio/netbird -go 1.24.10 +go 1.25 + +toolchain go1.25.5 require ( cunicu.li/go-rosenpass v0.4.0 @@ -40,7 +42,7 @@ require ( github.com/cilium/ebpf v0.15.0 github.com/coder/websocket v1.8.13 github.com/coreos/go-iptables v0.7.0 - github.com/creack/pty v1.1.18 + github.com/creack/pty v1.1.24 github.com/dexidp/dex v0.0.0-00010101000000-000000000000 github.com/dexidp/dex/api/v2 v2.4.0 github.com/eko/gocache/lib/v4 v4.2.0 @@ -81,7 +83,7 @@ require ( github.com/pion/turn/v3 v3.0.1 github.com/pkg/sftp v1.13.9 github.com/prometheus/client_golang v1.23.2 - github.com/quic-go/quic-go v0.49.1 + github.com/quic-go/quic-go v0.55.0 github.com/redis/go-redis/v9 v9.7.3 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 @@ -103,7 +105,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.38.0 go.opentelemetry.io/otel/sdk/metric v1.38.0 - go.uber.org/mock v0.5.0 + go.uber.org/mock v0.5.2 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 @@ -120,7 +122,7 @@ require ( gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.25.12 - gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 + gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c ) require ( @@ -186,12 +188,10 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect - github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-text/render v0.2.0 // indirect github.com/go-text/typesetting v0.2.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/btree v1.1.2 // indirect - github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect github.com/googleapis/gax-go/v2 v2.15.0 // indirect @@ -285,7 +285,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index 354c7732e..e89e0ef12 100644 --- a/go.sum +++ b/go.sum @@ -101,9 +101,6 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso= github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= @@ -121,8 +118,8 @@ github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmr github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= -github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= -github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0= github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -286,7 +283,6 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09 github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= -github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -411,8 +407,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ= -github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ= +github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= @@ -491,8 +487,8 @@ github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9Z github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= -github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0= -github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= +github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk= +github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= @@ -622,8 +618,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= -go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= +go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= @@ -717,7 +713,6 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -848,5 +843,5 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs= -gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= +gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA= +gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ= diff --git a/management/cmd/management.go b/management/cmd/management.go index 557cf45f8..5391b0866 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -64,7 +64,7 @@ var ( config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled } - tlsEnabled := false + var tlsEnabled bool if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") { tlsEnabled = true } diff --git a/management/internals/shared/grpc/loginfilter_test.go b/management/internals/shared/grpc/loginfilter_test.go index 8b26e14ab..797879ae7 100644 --- a/management/internals/shared/grpc/loginfilter_test.go +++ b/management/internals/shared/grpc/loginfilter_test.go @@ -85,6 +85,7 @@ func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() { s.True(s.filter.logged[pubKey].isBanned) s.Equal(2, s.filter.logged[pubKey].banLevel) secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen) + // nolint expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1)) s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond)) } diff --git a/management/server/account.go b/management/server/account.go index 26fb079f3..493159e37 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1006,7 +1006,7 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers for user, loggedInOnce := range accountUsers { if datum, ok := userDataMap[user]; ok { // check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed - if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple + if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user) return false } diff --git a/management/server/account_test.go b/management/server/account_test.go index d3a888c78..6fd48f227 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -753,7 +753,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to create an account for a user %s", userId) } - if account != nil && account.Domain != domain { + if account.Domain != domain { t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) } @@ -768,7 +768,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to get an account for a user %s", userId) } - if account != nil && account.Domain != domain { + if account.Domain != domain { t.Errorf("updating domain. expected %s got %s", domain, account.Domain) } } @@ -3479,11 +3479,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) { account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain}) require.NoError(t, err) - peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} + peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, Key: "key1", UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} err = manager.Store.AddPeerToAccount(ctx, peer1) require.NoError(t, err) - peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} + peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, Key: "key2", UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} err = manager.Store.AddPeerToAccount(ctx, peer2) require.NoError(t, err) diff --git a/management/server/group_test.go b/management/server/group_test.go index 493487b0b..ee4403a73 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -903,6 +903,7 @@ func Test_AddPeerAndAddToAll(t *testing.T) { peer := &peer2.Peer{ ID: strconv.Itoa(i), AccountID: accountID, + Key: "key" + strconv.Itoa(i), DNSLabel: "peer" + strconv.Itoa(i), IP: uint32ToIP(uint32(i)), } diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index 35198da32..a5999f6c7 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, "%v", err) //nolint } return postureChecks, nil diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 3fe3fe809..3345a034b 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -1,5 +1,4 @@ //go:build benchmark -// +build benchmark package benchmarks diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index 36b226db0..ca25861dd 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -1,5 +1,4 @@ //go:build benchmark -// +build benchmark package benchmarks diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index 2868a20bd..b13773268 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -1,5 +1,4 @@ //go:build benchmark -// +build benchmark package benchmarks diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index 1079de4aa..c1a9829da 100644 --- a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package integration diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go index 38a5cc67f..d8d764830 100644 --- a/management/server/idp/pocketid.go +++ b/management/server/idp/pocketid.go @@ -121,7 +121,7 @@ func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMet func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) { var MethodsWithBody = []string{http.MethodPost, http.MethodPut} if !slices.Contains(MethodsWithBody, method) && body != "" { - return nil, fmt.Errorf("Body provided to unsupported method: %s", method) + return nil, fmt.Errorf("body provided to unsupported method: %s", method) } reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource) @@ -301,7 +301,7 @@ func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID if p.appMetrics != nil { p.appMetrics.IDPMetrics().CountCreateUser() } - var pending bool = true + pending := true ret := &UserData{ Email: email, Name: name, diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 24228346a..8db3c4796 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -357,7 +357,7 @@ func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID return nil, err } - var pending bool = true + pending := true ret := &UserData{ Email: email, Name: name, diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 7a9155eba..c347f5089 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -393,7 +393,7 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s return fmt.Errorf("failed to parse model schema: %w", err) } tableName := stmt.Schema.Table - dialect := db.Dialector.Name() + dialect := db.Name() if db.Migrator().HasIndex(&model, indexName) { log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName) @@ -404,10 +404,11 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s if dialect == "mysql" { var withLength []string for _, col := range columns { - if col == "ip" || col == "dns_label" { - withLength = append(withLength, fmt.Sprintf("%s(64)", col)) + quotedCol := fmt.Sprintf("`%s`", col) + if col == "ip" || col == "dns_label" || col == "key" { + withLength = append(withLength, fmt.Sprintf("%s(64)", quotedCol)) } else { - withLength = append(withLength, col) + withLength = append(withLength, quotedCol) } } columnClause = strings.Join(withLength, ", ") @@ -488,6 +489,57 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri return nil } +func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error { + if !db.Migrator().HasTable("peers") { + log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup") + return nil + } + + keyColumn := GetColumnName(db, "key") + + var duplicates []struct { + Key string + Count int64 + } + + if err := db.Table("peers"). + Select(keyColumn + ", COUNT(*) as count"). + Group(keyColumn). + Having("COUNT(*) > 1"). + Find(&duplicates).Error; err != nil { + return fmt.Errorf("find duplicate keys: %w", err) + } + + if len(duplicates) == 0 { + return nil + } + + log.WithContext(ctx).Warnf("Found %d duplicate peer keys, cleaning up", len(duplicates)) + + for _, dup := range duplicates { + var peerIDs []string + if err := db.Table("peers"). + Select("id"). + Where(keyColumn+" = ?", dup.Key). + Order("peer_status_last_seen DESC"). + Pluck("id", &peerIDs).Error; err != nil { + return fmt.Errorf("get peers for key: %w", err) + } + + if len(peerIDs) <= 1 { + continue + } + + idsToDelete := peerIDs[1:] + + if err := db.Table("peers").Where("id IN ?", idsToDelete).Delete(nil).Error; err != nil { + return fmt.Errorf("delete duplicate peers: %w", err) + } + } + + return nil +} + // CleanupOrphanedIDs removes non-existent IDs from the JSON array column. // T is the type of the model that contains the list. // This migration cleans up the lists field by removing IDs that no longer exist in the target table. diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index ce76bd668..c1be8a3a3 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -340,3 +340,104 @@ func TestCreateIndexIfExists(t *testing.T) { exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName) assert.True(t, exist, "Should have the index") } + +type testPeer struct { + ID string `gorm:"primaryKey"` + Key string `gorm:"index"` + PeerStatusLastSeen time.Time + PeerStatusConnected bool +} + +func (testPeer) TableName() string { + return "peers" +} + +func setupPeerTestDB(t *testing.T) *gorm.DB { + t.Helper() + db := setupDatabase(t) + _ = db.Migrator().DropTable(&testPeer{}) + err := db.AutoMigrate(&testPeer{}) + require.NoError(t, err, "Failed to auto-migrate tables") + return db +} + +func TestRemoveDuplicatePeerKeys_NoDuplicates(t *testing.T) { + db := setupPeerTestDB(t) + + now := time.Now() + peers := []testPeer{ + {ID: "peer1", Key: "key1", PeerStatusLastSeen: now}, + {ID: "peer2", Key: "key2", PeerStatusLastSeen: now}, + {ID: "peer3", Key: "key3", PeerStatusLastSeen: now}, + } + + for _, p := range peers { + err := db.Create(&p).Error + require.NoError(t, err) + } + + err := migration.RemoveDuplicatePeerKeys(context.Background(), db) + require.NoError(t, err) + + var count int64 + db.Model(&testPeer{}).Count(&count) + assert.Equal(t, int64(len(peers)), count, "All peers should remain when no duplicates") +} + +func TestRemoveDuplicatePeerKeys_WithDuplicates(t *testing.T) { + db := setupPeerTestDB(t) + + now := time.Now() + peers := []testPeer{ + {ID: "peer1", Key: "key1", PeerStatusLastSeen: now.Add(-2 * time.Hour)}, + {ID: "peer2", Key: "key1", PeerStatusLastSeen: now.Add(-1 * time.Hour)}, + {ID: "peer3", Key: "key1", PeerStatusLastSeen: now}, + {ID: "peer4", Key: "key2", PeerStatusLastSeen: now}, + {ID: "peer5", Key: "key3", PeerStatusLastSeen: now.Add(-1 * time.Hour)}, + {ID: "peer6", Key: "key3", PeerStatusLastSeen: now}, + } + + for _, p := range peers { + err := db.Create(&p).Error + require.NoError(t, err) + } + + err := migration.RemoveDuplicatePeerKeys(context.Background(), db) + require.NoError(t, err) + + var count int64 + db.Model(&testPeer{}).Count(&count) + assert.Equal(t, int64(3), count, "Should have 3 peers after removing duplicates") + + var remainingPeers []testPeer + err = db.Find(&remainingPeers).Error + require.NoError(t, err) + + remainingIDs := make(map[string]bool) + for _, p := range remainingPeers { + remainingIDs[p.ID] = true + } + + assert.True(t, remainingIDs["peer3"], "peer3 should remain (most recent for key1)") + assert.True(t, remainingIDs["peer4"], "peer4 should remain (only peer for key2)") + assert.True(t, remainingIDs["peer6"], "peer6 should remain (most recent for key3)") + + assert.False(t, remainingIDs["peer1"], "peer1 should be deleted (older duplicate)") + assert.False(t, remainingIDs["peer2"], "peer2 should be deleted (older duplicate)") + assert.False(t, remainingIDs["peer5"], "peer5 should be deleted (older duplicate)") +} + +func TestRemoveDuplicatePeerKeys_EmptyTable(t *testing.T) { + db := setupPeerTestDB(t) + + err := migration.RemoveDuplicatePeerKeys(context.Background(), db) + require.NoError(t, err, "Should not fail on empty table") +} + +func TestRemoveDuplicatePeerKeys_NoTable(t *testing.T) { + db := setupDatabase(t) + _ = db.Migrator().DropTable(&testPeer{}) + + err := migration.RemoveDuplicatePeerKeys(context.Background(), db) + require.NoError(t, err, "Should not fail when table does not exist") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index f278e1761..a3eb4ae2e 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -20,7 +20,7 @@ import ( const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$` -var invalidDomainName = errors.New("invalid domain name") +var errInvalidDomainName = errors.New("invalid domain name") // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { @@ -314,7 +314,7 @@ func validateDomain(domain string) error { _, valid := dns.IsDomainName(domain) if !valid { - return invalidDomainName + return errInvalidDomainName } return nil diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index a898fd782..2439e8a22 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -19,7 +19,7 @@ type Peer struct { // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` // WireGuard public key - Key string `gorm:"index"` + Key string // uniqueness index (check migrations) // IP address of the Peer IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations) // Meta is a Peer system meta data diff --git a/management/server/peer_test.go b/management/server/peer_test.go index ce04adf9e..0160ff586 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -2129,12 +2129,14 @@ func Test_DeletePeer(t *testing.T) { "peer1": { ID: "peer1", AccountID: accountID, + Key: "key1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.test", }, "peer2": { ID: "peer2", AccountID: accountID, + Key: "key2", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.test", }, diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 9a743eb8c..ba901c771 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -158,7 +158,7 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.St // validatePostureChecks validates the posture checks. func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error { if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint + return status.Errorf(status.InvalidArgument, "%v", err.Error()) //nolint } // If the posture check already has an ID, verify its existence in the store. diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go index 8ff04d68a..69e346ae7 100644 --- a/management/server/store/sql_store_get_account_test.go +++ b/management/server/store/sql_store_get_account_test.go @@ -997,9 +997,10 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { // Find posture checks by ID var pc1, pc2 *posture.Checks for _, pc := range retrievedAccount.PostureChecks { - if pc.ID == postureCheckID1 { + switch pc.ID { + case postureCheckID1: pc1 = pc - } else if pc.ID == postureCheckID2 { + case postureCheckID2: pc2 = pc } } diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 4365f234e..9480af7b5 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -30,7 +30,6 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" nbroute "github.com/netbirdio/netbird/route" - route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/util/crypt" ) @@ -110,12 +109,12 @@ func runLargeTest(t *testing.T, store Store) { AccountID: account.Id, } account.Users[user.Id] = user - route := &route2.Route{ - ID: route2.ID(fmt.Sprintf("network-id-%d", n)), + route := &nbroute.Route{ + ID: nbroute.ID(fmt.Sprintf("network-id-%d", n)), Description: "base route", - NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)), + NetID: nbroute.NetID(fmt.Sprintf("network-id-%d", n)), Network: netip.MustParsePrefix(netIP.String() + "/24"), - NetworkType: route2.IPv4Network, + NetworkType: nbroute.IPv4Network, Metric: 9999, Masquerade: false, Enabled: true, @@ -689,7 +688,7 @@ func TestMigrate(t *testing.T) { require.NoError(t, err, "Failed to insert Gob data") type route struct { - route2.Route + nbroute.Route Network netip.Prefix `gorm:"serializer:gob"` PeerGroups []string `gorm:"serializer:gob"` } @@ -698,7 +697,7 @@ func TestMigrate(t *testing.T) { rt := &route{ Network: prefix, PeerGroups: []string{"group1", "group2"}, - Route: route2.Route{ID: "route1"}, + Route: nbroute.Route{ID: "route1"}, } err = store.(*SqlStore).db.Save(rt).Error @@ -714,7 +713,7 @@ func TestMigrate(t *testing.T) { require.NoError(t, err, "Failed to delete Gob data") prefix = netip.MustParsePrefix("12.0.0.0/24") - nRT := &route2.Route{ + nRT := &nbroute.Route{ Network: prefix, ID: "route2", Peer: "peer-id", @@ -969,6 +968,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { peer1 := &nbpeer.Peer{ ID: "peer1", AccountID: existingAccountID, + Key: "key1", DNSLabel: "peer1", IP: net.IP{1, 1, 1, 1}, } @@ -983,6 +983,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { peer2 := &nbpeer.Peer{ ID: "peer1second", AccountID: existingAccountID, + Key: "key2", DNSLabel: "peer1-1", IP: net.IP{2, 2, 2, 2}, } @@ -1010,6 +1011,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { peer1 := &nbpeer.Peer{ ID: "peer1", AccountID: existingAccountID, + Key: "key1", DNSLabel: "peer1", IP: net.IP{1, 1, 1, 1}, } @@ -1023,6 +1025,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { peer2 := &nbpeer.Peer{ ID: "peer1second", AccountID: existingAccountID, + Key: "key2", DNSLabel: "peer1-1", IP: net.IP{2, 2, 2, 2}, } @@ -1049,6 +1052,7 @@ func Test_AddPeerWithSameDnsLabel(t *testing.T) { peer1 := &nbpeer.Peer{ ID: "peer1", AccountID: existingAccountID, + Key: "key1", DNSLabel: "peer1.domain.test", } err = store.AddPeerToAccount(context.Background(), peer1) @@ -1057,6 +1061,7 @@ func Test_AddPeerWithSameDnsLabel(t *testing.T) { peer2 := &nbpeer.Peer{ ID: "peer1second", AccountID: existingAccountID, + Key: "key2", DNSLabel: "peer1.domain.test", } err = store.AddPeerToAccount(context.Background(), peer2) @@ -1074,6 +1079,7 @@ func Test_AddPeerWithSameIP(t *testing.T) { peer1 := &nbpeer.Peer{ ID: "peer1", AccountID: existingAccountID, + Key: "key1", IP: net.IP{1, 1, 1, 1}, } err = store.AddPeerToAccount(context.Background(), peer1) @@ -1082,6 +1088,7 @@ func Test_AddPeerWithSameIP(t *testing.T) { peer2 := &nbpeer.Peer{ ID: "peer1second", AccountID: existingAccountID, + Key: "key2", IP: net.IP{1, 1, 1, 1}, } err = store.AddPeerToAccount(context.Background(), peer2) @@ -3547,13 +3554,13 @@ func TestSqlStore_SaveRoute(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - route := &route2.Route{ + route := &nbroute.Route{ ID: "route-id", AccountID: accountID, Network: netip.MustParsePrefix("10.10.0.0/16"), NetID: "netID", PeerGroups: []string{"routeA"}, - NetworkType: route2.IPv4Network, + NetworkType: nbroute.IPv4Network, Masquerade: true, Metric: 9999, Enabled: true, @@ -3700,6 +3707,7 @@ func BenchmarkGetAccountPeers(b *testing.B) { peer := &nbpeer.Peer{ ID: fmt.Sprintf("peer-%d", i), AccountID: accountID, + Key: fmt.Sprintf("key-%d", i), DNSLabel: fmt.Sprintf("peer%d.example.com", i), IP: intToIPv4(uint32(i)), } diff --git a/management/server/store/store.go b/management/server/store/store.go index 372f2ebdc..cf73af341 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -352,11 +352,16 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.MigrateNewField[types.User](ctx, db, "email", "") }, + func(db *gorm.DB) error { + return migration.RemoveDuplicatePeerKeys(ctx, db) + }, func(db *gorm.DB) error { return migration.CleanupOrphanedIDs[types.User, types.Group](ctx, db, "auto_groups") }, } -} // migratePostAuto migrates the SQLite database to the latest schema +} + +// migratePostAuto migrates the SQLite database to the latest schema func migratePostAuto(ctx context.Context, db *gorm.DB) error { migrations := getMigrationsPostAuto(ctx) @@ -386,6 +391,12 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc { } }) }, + func(db *gorm.DB) error { + return migration.DropIndex[nbpeer.Peer](ctx, db, "idx_peers_key") + }, + func(db *gorm.DB) error { + return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key") + }, func(db *gorm.DB) error { return migration.MigrateJsonToTable[types.User](ctx, db, "auto_groups", func(accountID, id, value string) any { return &types.GroupUser{ diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 0393d1ade..9bb5dbace 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); -CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`); CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index a21783857..022508323 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -18,7 +18,7 @@ CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text, CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); -CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`); CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); CREATE INDEX `idx_peers_account_id_ip` ON `peers`(`account_id`,`ip`); CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); @@ -54,4 +54,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0); INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1'); INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network'); -INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0); +INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Do=','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0); diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql index a88411795..395276cb1 100644 --- a/management/server/testdata/store_policy_migrate.sql +++ b/management/server/testdata/store_policy_migrate.sql @@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); -CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`); CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index f2ef56a23..dfcaeee6f 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); -CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`); CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); @@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62 INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql index 8b09ec2be..eb5be31b7 100644 --- a/management/server/testdata/storev1.sql +++ b/management/server/testdata/storev1.sql @@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); -CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`); CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index db418c45b..f92153399 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -1,5 +1,4 @@ //go:build !ios -// +build !ios package testutil diff --git a/management/server/testutil/store_ios.go b/management/server/testutil/store_ios.go index c3dd839d3..9e3b5ce4a 100644 --- a/management/server/testutil/store_ios.go +++ b/management/server/testutil/store_ios.go @@ -1,5 +1,4 @@ //go:build ios -// +build ios package testutil diff --git a/relay/cmd/pprof.go b/relay/cmd/pprof.go index 37efd35f0..c041c6ea9 100644 --- a/relay/cmd/pprof.go +++ b/relay/cmd/pprof.go @@ -1,5 +1,4 @@ //go:build pprof -// +build pprof package cmd diff --git a/relay/server/listener/quic/conn.go b/relay/server/listener/quic/conn.go index 909ec1cc6..6e2201bf7 100644 --- a/relay/server/listener/quic/conn.go +++ b/relay/server/listener/quic/conn.go @@ -12,14 +12,14 @@ import ( ) type Conn struct { - session quic.Connection + session *quic.Conn closed bool closedMu sync.Mutex ctx context.Context ctxCancel context.CancelFunc } -func NewConn(session quic.Connection) *Conn { +func NewConn(session *quic.Conn) *Conn { ctx, cancel := context.WithCancel(context.Background()) return &Conn{ session: session, diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go index 3ec08945b..d5bce56f7 100644 --- a/relay/server/listener/ws/conn.go +++ b/relay/server/listener/ws/conn.go @@ -88,7 +88,7 @@ func (c *Conn) Close() error { c.closedMu.Lock() c.closed = true c.closedMu.Unlock() - return c.Conn.CloseNow() + return c.CloseNow() } func (c *Conn) isClosed() bool { diff --git a/shared/management/client/rest/accounts_test.go b/shared/management/client/rest/accounts_test.go index be0066488..e44ada298 100644 --- a/shared/management/client/rest/accounts_test.go +++ b/shared/management/client/rest/accounts_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/client.go b/shared/management/client/rest/client.go index 4d1de2631..77c960435 100644 --- a/shared/management/client/rest/client.go +++ b/shared/management/client/rest/client.go @@ -161,7 +161,7 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re func parseResponse[T any](resp *http.Response) (T, error) { var ret T if resp.Body == nil { - return ret, fmt.Errorf("Body missing, HTTP Error code %d", resp.StatusCode) + return ret, fmt.Errorf("body missing, HTTP Error code %d", resp.StatusCode) } bs, err := io.ReadAll(resp.Body) if err != nil { @@ -169,7 +169,7 @@ func parseResponse[T any](resp *http.Response) (T, error) { } err = json.Unmarshal(bs, &ret) if err != nil { - return ret, fmt.Errorf("Error code %d, error unmarshalling body: %w", resp.StatusCode, err) + return ret, fmt.Errorf("error code %d, error unmarshalling body: %w", resp.StatusCode, err) } return ret, nil diff --git a/shared/management/client/rest/client_test.go b/shared/management/client/rest/client_test.go index 17df8dd8b..2b3e6cabe 100644 --- a/shared/management/client/rest/client_test.go +++ b/shared/management/client/rest/client_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/dns_test.go b/shared/management/client/rest/dns_test.go index 58082abe8..8e8633f8d 100644 --- a/shared/management/client/rest/dns_test.go +++ b/shared/management/client/rest/dns_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/events_test.go b/shared/management/client/rest/events_test.go index b28390001..1ee10eb6e 100644 --- a/shared/management/client/rest/events_test.go +++ b/shared/management/client/rest/events_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/geo_test.go b/shared/management/client/rest/geo_test.go index fcb4808a1..2410f2641 100644 --- a/shared/management/client/rest/geo_test.go +++ b/shared/management/client/rest/geo_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/groups_test.go b/shared/management/client/rest/groups_test.go index fcd759e9a..51fd0c0ee 100644 --- a/shared/management/client/rest/groups_test.go +++ b/shared/management/client/rest/groups_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/impersonation_test.go b/shared/management/client/rest/impersonation_test.go index 4fb8f24eb..d257d0987 100644 --- a/shared/management/client/rest/impersonation_test.go +++ b/shared/management/client/rest/impersonation_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/networks_test.go b/shared/management/client/rest/networks_test.go index ca2a294ae..2bf1a0d3b 100644 --- a/shared/management/client/rest/networks_test.go +++ b/shared/management/client/rest/networks_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/peers_test.go b/shared/management/client/rest/peers_test.go index a45f9d6ec..c464de7ed 100644 --- a/shared/management/client/rest/peers_test.go +++ b/shared/management/client/rest/peers_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/policies_test.go b/shared/management/client/rest/policies_test.go index a19d0a728..e948e2949 100644 --- a/shared/management/client/rest/policies_test.go +++ b/shared/management/client/rest/policies_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/posturechecks_test.go b/shared/management/client/rest/posturechecks_test.go index 9b1b618df..d74d455a5 100644 --- a/shared/management/client/rest/posturechecks_test.go +++ b/shared/management/client/rest/posturechecks_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/routes_test.go b/shared/management/client/rest/routes_test.go index 9452a07fc..5ee2def24 100644 --- a/shared/management/client/rest/routes_test.go +++ b/shared/management/client/rest/routes_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/setupkeys_test.go b/shared/management/client/rest/setupkeys_test.go index 0fa782da5..bd8d3f835 100644 --- a/shared/management/client/rest/setupkeys_test.go +++ b/shared/management/client/rest/setupkeys_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/tokens_test.go b/shared/management/client/rest/tokens_test.go index ce3748751..5af41eb73 100644 --- a/shared/management/client/rest/tokens_test.go +++ b/shared/management/client/rest/tokens_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/management/client/rest/users_test.go b/shared/management/client/rest/users_test.go index d53c4eb6a..68815d4f9 100644 --- a/shared/management/client/rest/users_test.go +++ b/shared/management/client/rest/users_test.go @@ -1,5 +1,4 @@ //go:build integration -// +build integration package rest_test diff --git a/shared/relay/client/client_test.go b/shared/relay/client/client_test.go index 8fe5f04f4..9820d642f 100644 --- a/shared/relay/client/client_test.go +++ b/shared/relay/client/client_test.go @@ -19,15 +19,7 @@ import ( ) var ( - hmacTokenStore = &hmac.TokenStore{} - serverListenAddr = "127.0.0.1:1234" - serverURL = "rel://127.0.0.1:1234" - serverCfg = server.Config{ - Meter: otel.Meter(""), - ExposedAddress: serverURL, - TLSSupport: false, - AuthValidator: &allow.Auth{}, - } + hmacTokenStore = &hmac.TokenStore{} ) func TestMain(m *testing.M) { @@ -36,8 +28,20 @@ func TestMain(m *testing.M) { os.Exit(code) } +// newClientTestServerConfig creates a new server config for client testing with the given address +func newClientTestServerConfig(address string) server.Config { + return server.Config{ + Meter: otel.Meter(""), + ExposedAddress: "rel://" + address, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } +} + func TestClient(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:50001" + serverCfg := newClientTestServerConfig(serverListenAddr) srv, err := server.NewServer(serverCfg) if err != nil { @@ -64,7 +68,7 @@ func TestClient(t *testing.T) { t.Fatalf("failed to start server: %s", err) } t.Log("alice connecting to server") - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -72,7 +76,7 @@ func TestClient(t *testing.T) { defer clientAlice.Close() t.Log("placeholder connecting to server") - clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU) + clientPlaceHolder := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU) err = clientPlaceHolder.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -80,7 +84,7 @@ func TestClient(t *testing.T) { defer clientPlaceHolder.Close() t.Log("Bob connecting to server") - clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) + clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU) err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -120,6 +124,8 @@ func TestClient(t *testing.T) { func TestRegistration(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:50101" + serverCfg := newClientTestServerConfig(serverListenAddr) srvCfg := server.ListenerConfig{Address: serverListenAddr} srv, err := server.NewServer(serverCfg) if err != nil { @@ -138,7 +144,7 @@ func TestRegistration(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { _ = srv.Shutdown(ctx) @@ -157,7 +163,7 @@ func TestRegistration(t *testing.T) { func TestRegistrationTimeout(t *testing.T) { ctx := context.Background() fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{ - Port: 1234, + Port: 50201, IP: net.ParseIP("0.0.0.0"), }) if err != nil { @@ -168,7 +174,7 @@ func TestRegistrationTimeout(t *testing.T) { }(fakeUDPListener) fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{ - Port: 1234, + Port: 50201, IP: net.ParseIP("0.0.0.0"), }) if err != nil { @@ -178,7 +184,7 @@ func TestRegistrationTimeout(t *testing.T) { _ = fakeTCPListener.Close() }(fakeTCPListener) - clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient("127.0.0.1:50201", hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err == nil { t.Errorf("failed to connect to server: %s", err) @@ -192,6 +198,8 @@ func TestRegistrationTimeout(t *testing.T) { func TestEcho(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:50301" + serverCfg := newClientTestServerConfig(serverListenAddr) idAlice := "alice" idBob := "bob" srvCfg := server.ListenerConfig{Address: serverListenAddr} @@ -219,7 +227,7 @@ func TestEcho(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) + clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -231,7 +239,7 @@ func TestEcho(t *testing.T) { } }() - clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU) + clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idBob, iface.DefaultMTU) err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -282,6 +290,8 @@ func TestEcho(t *testing.T) { func TestBindToUnavailabePeer(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:50401" + serverCfg := newClientTestServerConfig(serverListenAddr) srvCfg := server.ListenerConfig{Address: serverListenAddr} srv, err := server.NewServer(serverCfg) @@ -309,7 +319,7 @@ func TestBindToUnavailabePeer(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -328,6 +338,8 @@ func TestBindToUnavailabePeer(t *testing.T) { func TestBindReconnect(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:50501" + serverCfg := newClientTestServerConfig(serverListenAddr) srvCfg := server.ListenerConfig{Address: serverListenAddr} srv, err := server.NewServer(serverCfg) @@ -355,13 +367,13 @@ func TestBindReconnect(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } - clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) + clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU) err = clientBob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -383,7 +395,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } - clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice = NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -429,6 +441,8 @@ func TestBindReconnect(t *testing.T) { func TestCloseConn(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:50601" + serverCfg := newClientTestServerConfig(serverListenAddr) srvCfg := server.ListenerConfig{Address: serverListenAddr} srv, err := server.NewServer(serverCfg) @@ -456,13 +470,13 @@ func TestCloseConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) + bob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU) err = bob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -492,6 +506,8 @@ func TestCloseConn(t *testing.T) { func TestCloseRelayConn(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:50701" + serverCfg := newClientTestServerConfig(serverListenAddr) srvCfg := server.ListenerConfig{Address: serverListenAddr} srv, err := server.NewServer(serverCfg) @@ -518,13 +534,13 @@ func TestCloseRelayConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) + bob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "bob", iface.DefaultMTU) err = bob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -550,6 +566,8 @@ func TestCloseRelayConn(t *testing.T) { func TestCloseByServer(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:50801" + serverCfg := newClientTestServerConfig(serverListenAddr) srvCfg := server.ListenerConfig{Address: serverListenAddr} srv1, err := server.NewServer(serverCfg) @@ -572,7 +590,7 @@ func TestCloseByServer(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) + relayClient := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU) if err = relayClient.Connect(ctx); err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -607,6 +625,8 @@ func TestCloseByServer(t *testing.T) { func TestCloseByClient(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:50901" + serverCfg := newClientTestServerConfig(serverListenAddr) srvCfg := server.ListenerConfig{Address: serverListenAddr} srv, err := server.NewServer(serverCfg) @@ -628,7 +648,7 @@ func TestCloseByClient(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) + relayClient := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU) err = relayClient.Connect(ctx) if err != nil { log.Fatalf("failed to connect to server: %s", err) @@ -652,6 +672,8 @@ func TestCloseByClient(t *testing.T) { func TestCloseNotDrainedChannel(t *testing.T) { ctx := context.Background() + serverListenAddr := "127.0.0.1:51001" + serverCfg := newClientTestServerConfig(serverListenAddr) idAlice := "alice" idBob := "bob" srvCfg := server.ListenerConfig{Address: serverListenAddr} @@ -679,7 +701,7 @@ func TestCloseNotDrainedChannel(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) + clientAlice := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idAlice, iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -691,7 +713,7 @@ func TestCloseNotDrainedChannel(t *testing.T) { } }() - clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU) + clientBob := NewClient(serverCfg.ExposedAddress, hmacTokenStore, idBob, iface.DefaultMTU) err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) diff --git a/shared/relay/client/dialer/quic/conn.go b/shared/relay/client/dialer/quic/conn.go index 9243605b5..1d90d7139 100644 --- a/shared/relay/client/dialer/quic/conn.go +++ b/shared/relay/client/dialer/quic/conn.go @@ -30,11 +30,11 @@ func (a Addr) String() string { } type Conn struct { - session quic.Connection + session *quic.Conn ctx context.Context } -func NewConn(session quic.Connection) net.Conn { +func NewConn(session *quic.Conn) net.Conn { return &Conn{ session: session, ctx: context.Background(), diff --git a/shared/relay/client/manager_test.go b/shared/relay/client/manager_test.go index f00b35707..fb91f7682 100644 --- a/shared/relay/client/manager_test.go +++ b/shared/relay/client/manager_test.go @@ -13,6 +13,16 @@ import ( "github.com/netbirdio/netbird/shared/relay/auth/allow" ) +// newManagerTestServerConfig creates a new server config for manager testing with the given address +func newManagerTestServerConfig(address string) server.Config { + return server.Config{ + Meter: otel.Meter(""), + ExposedAddress: address, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } +} + func TestEmptyURL(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -27,15 +37,10 @@ func TestForeignConn(t *testing.T) { ctx := context.Background() lstCfg1 := server.ListenerConfig{ - Address: "localhost:1234", + Address: "localhost:52101", } - srv1, err := server.NewServer(server.Config{ - Meter: otel.Meter(""), - ExposedAddress: lstCfg1.Address, - TLSSupport: false, - AuthValidator: &allow.Auth{}, - }) + srv1, err := server.NewServer(newManagerTestServerConfig(lstCfg1.Address)) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -59,14 +64,9 @@ func TestForeignConn(t *testing.T) { } srvCfg2 := server.ListenerConfig{ - Address: "localhost:2234", + Address: "localhost:52102", } - srv2, err := server.NewServer(server.Config{ - Meter: otel.Meter(""), - ExposedAddress: srvCfg2.Address, - TLSSupport: false, - AuthValidator: &allow.Auth{}, - }) + srv2, err := server.NewServer(newManagerTestServerConfig(srvCfg2.Address)) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -144,9 +144,9 @@ func TestForeginConnClose(t *testing.T) { ctx := context.Background() srvCfg1 := server.ListenerConfig{ - Address: "localhost:1234", + Address: "localhost:52201", } - srv1, err := server.NewServer(serverCfg) + srv1, err := server.NewServer(newManagerTestServerConfig(srvCfg1.Address)) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -170,9 +170,9 @@ func TestForeginConnClose(t *testing.T) { } srvCfg2 := server.ListenerConfig{ - Address: "localhost:2234", + Address: "localhost:52202", } - srv2, err := server.NewServer(serverCfg) + srv2, err := server.NewServer(newManagerTestServerConfig(srvCfg2.Address)) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -225,9 +225,9 @@ func TestForeignAutoClose(t *testing.T) { keepUnusedServerTime = 2 * time.Second srvCfg1 := server.ListenerConfig{ - Address: "localhost:1234", + Address: "localhost:52301", } - srv1, err := server.NewServer(serverCfg) + srv1, err := server.NewServer(newManagerTestServerConfig(srvCfg1.Address)) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -252,9 +252,9 @@ func TestForeignAutoClose(t *testing.T) { } srvCfg2 := server.ListenerConfig{ - Address: "localhost:2234", + Address: "localhost:52302", } - srv2, err := server.NewServer(serverCfg) + srv2, err := server.NewServer(newManagerTestServerConfig(srvCfg2.Address)) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -327,9 +327,9 @@ func TestAutoReconnect(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{ - Address: "localhost:1234", + Address: "localhost:52401", } - srv, err := server.NewServer(serverCfg) + srv, err := server.NewServer(newManagerTestServerConfig(srvCfg.Address)) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -397,14 +397,9 @@ func TestNotifierDoubleAdd(t *testing.T) { ctx := context.Background() listenerCfg1 := server.ListenerConfig{ - Address: "localhost:1234", + Address: "localhost:52501", } - srv, err := server.NewServer(server.Config{ - Meter: otel.Meter(""), - ExposedAddress: listenerCfg1.Address, - TLSSupport: false, - AuthValidator: &allow.Auth{}, - }) + srv, err := server.NewServer(newManagerTestServerConfig(listenerCfg1.Address)) if err != nil { t.Fatalf("failed to create server: %s", err) } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index bf8f8e327..d7662a886 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -73,7 +73,7 @@ var ( // detect whether user specified a port userPort := cmd.Flag("port").Changed - tlsEnabled := false + var tlsEnabled bool if signalLetsencryptDomain != "" || (signalCertFile != "" && signalCertKey != "") { tlsEnabled = true } @@ -259,8 +259,8 @@ func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent: + switch r.URL.Path { + case wsproxy.ProxyPath + wsproxy.SignalComponent: wsProxy.Handler().ServeHTTP(w, r) default: grpcServer.ServeHTTP(w, r) diff --git a/util/syslog_nonwindows.go b/util/syslog_nonwindows.go index 6ffbcb8be..328bb8b1c 100644 --- a/util/syslog_nonwindows.go +++ b/util/syslog_nonwindows.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package util