diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index e727aa4e5..d585ba209 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -146,6 +146,65 @@ jobs: - name: Test run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay) + test_client_on_docker: + name: "Client (Docker) / Unit" + needs: [build-cache] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + id: go-env + run: | + echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT + echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT + + - name: Cache Go modules + uses: actions/cache/restore@v4 + id: cache-restore + with: + path: | + ${{ steps.go-env.outputs.cache_dir }} + ${{ steps.go-env.outputs.modcache_dir }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Run tests in container + env: + HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }} + HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }} + run: | + CONTAINER_GOCACHE="/root/.cache/go-build" + CONTAINER_GOMODCACHE="/go/pkg/mod" + + docker run --rm \ + --cap-add=NET_ADMIN \ + --privileged \ + -v $PWD:/app \ + -w /app \ + -v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \ + -v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \ + -e CGO_ENABLED=1 \ + -e CI=true \ + -e DOCKER_CI=true \ + -e GOARCH=${GOARCH_TARGET} \ + -e GOCACHE=${CONTAINER_GOCACHE} \ + -e GOMODCACHE=${CONTAINER_GOMODCACHE} \ + golang:1.23-alpine \ + sh -c ' \ + apk update; apk add --no-cache \ + ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ + go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server) + ' + test_relay: name: "Relay / Unit" needs: [build-cache] @@ -179,13 +238,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -232,13 +284,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -286,13 +331,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -354,13 +392,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -385,7 +416,7 @@ jobs: CI=true \ go test -tags devcert -run=^$ -bench=. \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 20m ./... + -timeout 20m ./management/... api_benchmark: name: "Management / Benchmark (API)" @@ -449,13 +480,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -520,13 +544,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -541,85 +558,3 @@ jobs: go test -tags=integration \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -timeout 20m ./management/... - - test_client_on_docker: - name: "Client (Docker) / Unit" - needs: [ build-cache ] - runs-on: ubuntu-20.04 - steps: - - name: Install Go - uses: actions/setup-go@v5 - with: - go-version: "1.23.x" - cache: false - - - name: Checkout code - uses: actions/checkout@v4 - - - name: Get Go environment - run: | - echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV - echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV - - - name: Cache Go modules - uses: actions/cache/restore@v4 - with: - path: | - ${{ env.cache }} - ${{ env.modcache }} - key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-gotest-cache- - - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install modules - run: go mod tidy - - - name: check git status - run: git --no-pager diff --exit-code - - - name: Generate Shared Sock Test bin - run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - - - name: Generate RouteManager Test bin - run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager - - - name: Generate SystemOps Test bin - run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops - - - name: Generate nftables Manager Test bin - run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... - - - name: Generate Engine Test bin - run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal - - - name: Generate Peer Test bin - run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/ - - - run: chmod +x *testing.bin - - - name: Run Shared Sock tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 - - - name: Run Iface tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/... - - - name: Run RouteManager tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 - - - name: Run SystemOps tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1 - - - name: Run nftables Manager tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1 - - - name: Run Engine tests in docker with file store - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 - - - name: Run Engine tests in docker with sqlite store - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 - - - name: Run Peer tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1 diff --git a/.goreleaser.yaml b/.goreleaser.yaml index d6479763e..112659d1c 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -96,6 +96,20 @@ builds: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser mod_timestamp: "{{ .CommitTimestamp }}" + - id: netbird-upload + dir: upload-server + env: [CGO_ENABLED=0] + binary: netbird-upload + goos: + - linux + goarch: + - amd64 + - arm64 + - arm + ldflags: + - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser + mod_timestamp: "{{ .CommitTimestamp }}" + universal_binaries: - id: netbird @@ -409,6 +423,52 @@ dockers: - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.version={{.Version}}" - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/upload:{{ .Version }}-amd64 + ids: + - netbird-upload + goarch: amd64 + use: buildx + dockerfile: upload-server/Dockerfile + build_flag_templates: + - "--platform=linux/amd64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/upload:{{ .Version }}-arm64v8 + ids: + - netbird-upload + goarch: arm64 + use: buildx + dockerfile: upload-server/Dockerfile + build_flag_templates: + - "--platform=linux/arm64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/upload:{{ .Version }}-arm + ids: + - netbird-upload + goarch: arm + goarm: 6 + use: buildx + dockerfile: upload-server/Dockerfile + build_flag_templates: + - "--platform=linux/arm" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=maintainer=dev@netbird.io" docker_manifests: - name_template: netbirdio/netbird:{{ .Version }} image_templates: @@ -475,7 +535,17 @@ docker_manifests: - netbirdio/management:{{ .Version }}-debug-arm64v8 - netbirdio/management:{{ .Version }}-debug-arm - netbirdio/management:{{ .Version }}-debug-amd64 + - name_template: netbirdio/upload:{{ .Version }} + image_templates: + - netbirdio/upload:{{ .Version }}-arm64v8 + - netbirdio/upload:{{ .Version }}-arm + - netbirdio/upload:{{ .Version }}-amd64 + - name_template: netbirdio/upload:latest + image_templates: + - netbirdio/upload:{{ .Version }}-arm64v8 + - netbirdio/upload:{{ .Version }}-arm + - netbirdio/upload:{{ .Version }}-amd64 brews: - ids: - default diff --git a/client/Dockerfile b/client/Dockerfile index 35c1d04c2..16b2916c7 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,5 +1,6 @@ FROM alpine:3.21.3 -RUN apk add --no-cache ca-certificates iptables ip6tables +# iproute2: busybox doesn't display ip rules properly +RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables ENV NB_FOREGROUND_MODE=true ENTRYPOINT [ "/usr/local/bin/netbird","up"] -COPY netbird /usr/local/bin/netbird \ No newline at end of file +COPY netbird /usr/local/bin/netbird diff --git a/client/cmd/debug.go b/client/cmd/debug.go index d2e5bdd7e..385bd95f5 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -87,16 +87,27 @@ func debugBundle(cmd *cobra.Command, _ []string) error { }() client := proto.NewDaemonServiceClient(conn) - resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ + request := &proto.DebugBundleRequest{ Anonymize: anonymizeFlag, Status: getStatusOutput(cmd, anonymizeFlag), SystemInfo: debugSystemInfoFlag, - }) + } + if debugUploadBundle { + request.UploadURL = debugUploadBundleURL + } + resp, err := client.DebugBundle(cmd.Context(), request) if err != nil { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } + cmd.Printf("Local file:\n%s\n", resp.GetPath()) - cmd.Println(resp.GetPath()) + if resp.GetUploadFailureReason() != "" { + return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) + } + + if debugUploadBundle { + cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) + } return nil } @@ -211,23 +222,19 @@ func runForDuration(cmd *cobra.Command, args []string) error { headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) - - resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ + request := &proto.DebugBundleRequest{ Anonymize: anonymizeFlag, Status: statusOutput, SystemInfo: debugSystemInfoFlag, - }) + } + if debugUploadBundle { + request.UploadURL = debugUploadBundleURL + } + resp, err := client.DebugBundle(cmd.Context(), request) if err != nil { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } - // Disable network map persistence after creating the debug bundle - if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{ - Enabled: false, - }); err != nil { - return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message()) - } - if stateWasDown { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) @@ -242,7 +249,15 @@ func runForDuration(cmd *cobra.Command, args []string) error { cmd.Println("Log level restored to", initialLogLevel.GetLevel()) } - cmd.Println(resp.GetPath()) + cmd.Printf("Local file:\n%s\n", resp.GetPath()) + + if resp.GetUploadFailureReason() != "" { + return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) + } + + if debugUploadBundle { + cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) + } return nil } diff --git a/client/cmd/login.go b/client/cmd/login.go index c86d6c636..84906a7a4 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -55,6 +55,9 @@ var loginCmd = &cobra.Command{ return err } + // update host's static platform and system information + system.UpdateStaticInfo() + ic := internal.ConfigInput{ ManagementURL: managementURL, AdminURL: adminURL, diff --git a/client/cmd/root.go b/client/cmd/root.go index baf444b99..b57bee230 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -22,6 +22,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/upload-server/types" ) const ( @@ -39,6 +40,8 @@ const ( dnsRouteIntervalFlag = "dns-router-interval" systemInfoFlag = "system-info" blockLANAccessFlag = "block-lan-access" + uploadBundle = "upload-bundle" + uploadBundleURL = "upload-bundle-url" ) var ( @@ -75,6 +78,8 @@ var ( debugSystemInfoFlag bool dnsRouteInterval time.Duration blockLANAccess bool + debugUploadBundle bool + debugUploadBundleURL string rootCmd = &cobra.Command{ Use: "netbird", @@ -181,6 +186,8 @@ func init() { upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") + debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) + debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") } // SetupCloseHandler handles SIGTERM signal and exits with success diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 0ddf6c4c8..5e3c63e57 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -16,12 +16,17 @@ import ( "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/server" + "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/util" ) func (p *program) Start(svc service.Service) error { // Start should not block. Do the actual work async. log.Info("starting Netbird service") //nolint + + // Collect static system and platform information + system.UpdateStaticInfo() + // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. p.serv = grpc.NewServer() diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 70abe4abe..258a8daff 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -98,6 +98,11 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc settingsMockManager := settings.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl) + settingsMockManager.EXPECT(). + GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&types.Settings{}, nil). + AnyTimes() + accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) if err != nil { t.Fatal(err) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 652ab1b3e..b229688fc 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, + sPort, dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - if !destination.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -243,6 +242,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.router.DeleteDNATRule(rule) } +// UpdateSet updates the set with the given prefixes +func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.UpdateSet(set, prefixes) +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 869b0b359..bb799b99b 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -57,18 +57,18 @@ type ruleInfo struct { } type routeFilteringRuleParams struct { - Sources []netip.Prefix - Destination netip.Prefix + Source firewall.Network + Destination firewall.Network Proto firewall.Protocol SPort *firewall.Port DPort *firewall.Port Direction firewall.RuleDirection Action firewall.Action - SetName string } type routeRules map[string][]string +// the ipset library currently does not support comments, so we use the name only (string) type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] type router struct { @@ -129,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error { func (r *router) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, @@ -140,27 +140,28 @@ func (r *router) AddRouteFiltering( return ruleKey, nil } - var setName string + var source firewall.Network if len(sources) > 1 { - setName = firewall.GenerateSetName(sources) - if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { - return nil, fmt.Errorf("create or get ipset: %w", err) - } + source.Set = firewall.NewPrefixSet(sources) + } else if len(sources) > 0 { + source.Prefix = sources[0] } params := routeFilteringRuleParams{ - Sources: sources, + Source: source, Destination: destination, Proto: proto, SPort: sPort, DPort: dPort, Action: action, - SetName: setName, } - rule := genRouteFilteringRuleSpec(params) + rule, err := r.genRouteRuleSpec(params, sources) + if err != nil { + return nil, fmt.Errorf("generate route rule spec: %w", err) + } + // Insert DROP rules at the beginning, append ACCEPT rules at the end - var err error if action == firewall.ActionDrop { // after the established rule err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...) @@ -183,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { ruleKey := rule.ID() if rule, exists := r.rules[ruleKey]; exists { - setName := r.findSetNameInRule(rule) - if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil { return fmt.Errorf("delete route rule: %v", err) } delete(r.rules, ruleKey) - if setName != "" { - if _, err := r.ipsetCounter.Decrement(setName); err != nil { - return fmt.Errorf("failed to remove ipset: %w", err) - } + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement ipset counter: %w", err) } } else { log.Debugf("route rule %s not found", ruleKey) @@ -204,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return nil } -func (r *router) findSetNameInRule(rule []string) string { - for i, arg := range rule { - if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { - return rule[i+3] +func (r *router) decrementSetCounter(rule []string) error { + sets := r.findSets(rule) + var merr *multierror.Error + for _, setName := range sets { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err)) } } - return "" + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) findSets(rule []string) []string { + var sets []string + for i, arg := range rule { + if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { + sets = append(sets, rule[i+3]) + } + } + return sets } func (r *router) createIpSet(setName string, sources []netip.Prefix) error { @@ -231,6 +241,8 @@ func (r *router) deleteIpSet(setName string) error { if err := ipset.Destroy(setName); err != nil { return fmt.Errorf("destroy set %s: %w", setName, err) } + + log.Debugf("Deleted unused ipset %s", setName) return nil } @@ -270,12 +282,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { log.Errorf("%v", err) } - if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove nat rule: %w", err) - } + if pair.Masquerade { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat rule: %w", err) + } - if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { - return fmt.Errorf("remove inverse nat rule: %w", err) + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } } if err := r.removeLegacyRouteRule(pair); err != nil { @@ -313,8 +327,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } delete(r.rules, ruleKey) - } else { - log.Debugf("legacy forwarding rule %s not found", ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement ipset counter: %w", err) + } } return nil @@ -599,12 +615,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { rule = append(rule, "-m", "conntrack", "--ctstate", "NEW", - "-s", pair.Source.String(), - "-d", pair.Destination.String(), + ) + sourceExp, err := r.applyNetwork("-s", pair.Source, nil) + if err != nil { + return fmt.Errorf("apply network -s: %w", err) + } + destExp, err := r.applyNetwork("-d", pair.Destination, nil) + if err != nil { + return fmt.Errorf("apply network -d: %w", err) + } + + rule = append(rule, sourceExp...) + rule = append(rule, destExp...) + rule = append(rule, "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), ) - if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil { + // Ensure nat rules come first, so the mark can be overwritten. + // Currently overwritten by the dst-type LOCAL rules for redirected traffic. + if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil { + // TODO: rollback ipset counter return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) } @@ -622,6 +652,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err) } delete(r.rules, ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement ipset counter: %w", err) + } } else { log.Debugf("marking rule %s not found", ruleKey) } @@ -787,17 +821,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { return nberrors.FormatErrorOrNil(merr) } -func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { +func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) { var rule []string - if params.SetName != "" { - rule = append(rule, "-m", "set", matchSet, params.SetName, "src") - } else if len(params.Sources) > 0 { - source := params.Sources[0] - rule = append(rule, "-s", source.String()) + sourceExp, err := r.applyNetwork("-s", params.Source, sources) + if err != nil { + return nil, fmt.Errorf("apply network -s: %w", err) + + } + destExp, err := r.applyNetwork("-d", params.Destination, nil) + if err != nil { + return nil, fmt.Errorf("apply network -d: %w", err) } - rule = append(rule, "-d", params.Destination.String()) + rule = append(rule, sourceExp...) + rule = append(rule, destExp...) if params.Proto != firewall.ProtocolALL { rule = append(rule, "-p", strings.ToLower(string(params.Proto))) @@ -807,7 +845,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { rule = append(rule, "-j", actionToStr(params.Action)) - return rule + return rule, nil +} + +func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) { + direction := "src" + if flag == "-d" { + direction = "dst" + } + + if network.IsSet() { + if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { + return nil, fmt.Errorf("create or get ipset: %w", err) + } + + return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil + } + if network.IsPrefix() { + return []string{flag, network.Prefix.String()}, nil + } + + // nolint:nilnil + return nil, nil +} + +func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + var merr *multierror.Error + for _, prefix := range prefixes { + // TODO: Implement IPv6 support + if prefix.Addr().Is6() { + log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + continue + } + if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err)) + } + } + if merr == nil { + log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + } + + return nberrors.FormatErrorOrNil(merr) } func applyPort(flag string, port *firewall.Port) []string { diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index dad77dee7..e9eeff863 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -60,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { pair := firewall.RouterPair{ ID: "abc", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.100.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")}, Masquerade: true, } @@ -332,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") // Check if the rule is in the internal map @@ -347,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) { assert.NoError(t, err, "Failed to check rule existence") assert.True(t, exists, "Rule not found in iptables") + var source firewall.Network + if len(tt.sources) > 1 { + source.Set = firewall.NewPrefixSet(tt.sources) + } else if len(tt.sources) > 0 { + source.Prefix = tt.sources[0] + } // Verify rule content params := routeFilteringRuleParams{ - Sources: tt.sources, - Destination: tt.destination, + Source: source, + Destination: firewall.Network{Prefix: tt.destination}, Proto: tt.proto, SPort: tt.sPort, DPort: tt.dPort, Action: tt.action, - SetName: "", } - expectedRule := genRouteFilteringRuleSpec(params) + expectedRule, err := r.genRouteRuleSpec(params, nil) + require.NoError(t, err, "Failed to generate expected rule spec") if tt.expectSet { - setName := firewall.GenerateSetName(tt.sources) - params.SetName = setName - expectedRule = genRouteFilteringRuleSpec(params) + setName := firewall.NewPrefixSet(tt.sources).HashedName() + expectedRule, err = r.genRouteRuleSpec(params, nil) + require.NoError(t, err, "Failed to generate expected rule spec with set") // Check if the set was created _, exists := r.ipsetCounter.Get(setName) @@ -378,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) { }) } } + +func TestFindSetNameInRule(t *testing.T) { + r := &router{} + + testCases := []struct { + name string + rule []string + expected []string + }{ + { + name: "Basic rule with two sets", + rule: []string{ + "-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src", + "-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT", + }, + expected: []string{"nb-2e5a2a05", "nb-349ae051"}, + }, + { + name: "No sets", + rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"}, + expected: []string{}, + }, + { + name: "Multiple sets with different positions", + rule: []string{ + "-m", "set", "--match-set", "set1", "src", "-p", "tcp", + "-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT", + }, + expected: []string{"set1", "set-abc123"}, + }, + { + name: "Boundary case - sequence appears at end", + rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"}, + expected: []string{"final-set"}, + }, + { + name: "Incomplete pattern - missing set name", + rule: []string{"-p", "tcp", "-m", "set", "--match-set"}, + expected: []string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := r.findSets(tc.rule) + + if len(result) != len(tc.expected) { + t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result) + return + } + + for i, set := range result { + if set != tc.expected[i] { + t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set) + } + } + }) + } +} diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 1d71051ef..084d19423 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,13 +1,10 @@ package manager import ( - "crypto/sha256" - "encoding/hex" "fmt" "net" "net/netip" "sort" - "strings" log "github.com/sirupsen/logrus" @@ -43,6 +40,18 @@ const ( // Action is the action to be taken on a rule type Action int +// String returns the string representation of the action +func (a Action) String() string { + switch a { + case ActionAccept: + return "accept" + case ActionDrop: + return "drop" + default: + return "unknown" + } +} + const ( // ActionAccept is the action to accept a packet ActionAccept Action = iota @@ -50,6 +59,33 @@ const ( ActionDrop ) +// Network is a rule destination, either a set or a prefix +type Network struct { + Set Set + Prefix netip.Prefix +} + +// String returns the string representation of the destination +func (d Network) String() string { + if d.Prefix.IsValid() { + return d.Prefix.String() + } + if d.IsSet() { + return d.Set.HashedName() + } + return "" +} + +// IsSet returns true if the destination is a set +func (d Network) IsSet() bool { + return d.Set != Set{} +} + +// IsPrefix returns true if the destination is a valid prefix +func (d Network) IsPrefix() bool { + return d.Prefix.IsValid() +} + // Manager is the high level abstraction of a firewall manager // // It declares methods which handle actions required by the @@ -83,10 +119,9 @@ type Manager interface { AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination Network, proto Protocol, - sPort *Port, - dPort *Port, + sPort, dPort *Port, action Action, ) (Rule, error) @@ -119,6 +154,9 @@ type Manager interface { // DeleteDNATRule deletes a DNAT rule DeleteDNATRule(Rule) error + + // UpdateSet updates the set with the given prefixes + UpdateSet(hash Set, prefixes []netip.Prefix) error } func GenKey(format string, pair RouterPair) string { @@ -153,22 +191,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error { return nil } -// GenerateSetName generates a unique name for an ipset based on the given sources. -func GenerateSetName(sources []netip.Prefix) string { - // sort for consistent naming - SortPrefixes(sources) - - var sourcesStr strings.Builder - for _, src := range sources { - sourcesStr.WriteString(src.String()) - } - - hash := sha256.Sum256([]byte(sourcesStr.String())) - shortHash := hex.EncodeToString(hash[:])[:8] - - return fmt.Sprintf("nb-%s", shortHash) -} - // MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { if len(prefixes) == 0 { diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go index 3f47d6679..180346906 100644 --- a/client/firewall/manager/firewall_test.go +++ b/client/firewall/manager/firewall_test.go @@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) { netip.MustParsePrefix("192.168.1.0/24"), } - result1 := manager.GenerateSetName(prefixes1) - result2 := manager.GenerateSetName(prefixes2) + result1 := manager.NewPrefixSet(prefixes1) + result2 := manager.NewPrefixSet(prefixes2) if result1 != result2 { t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) @@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) { netip.MustParsePrefix("10.0.0.0/8"), } - result := manager.GenerateSetName(prefixes) + result := manager.NewPrefixSet(prefixes) - matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) + matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName()) if err != nil { t.Fatalf("Error matching regex: %v", err) } @@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) { }) t.Run("Empty input produces consistent result", func(t *testing.T) { - result1 := manager.GenerateSetName([]netip.Prefix{}) - result2 := manager.GenerateSetName([]netip.Prefix{}) + result1 := manager.NewPrefixSet([]netip.Prefix{}) + result2 := manager.NewPrefixSet([]netip.Prefix{}) if result1 != result2 { t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) @@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) { netip.MustParsePrefix("192.168.1.0/24"), } - result1 := manager.GenerateSetName(prefixes1) - result2 := manager.GenerateSetName(prefixes2) + result1 := manager.NewPrefixSet(prefixes1) + result2 := manager.NewPrefixSet(prefixes2) if result1 != result2 { t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go index 8c94b7dd4..079c051d9 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/manager/routerpair.go @@ -1,15 +1,13 @@ package manager import ( - "net/netip" - "github.com/netbirdio/netbird/route" ) type RouterPair struct { ID route.ID - Source netip.Prefix - Destination netip.Prefix + Source Network + Destination Network Masquerade bool Inverse bool } diff --git a/client/firewall/manager/set.go b/client/firewall/manager/set.go new file mode 100644 index 000000000..4c88f6eac --- /dev/null +++ b/client/firewall/manager/set.go @@ -0,0 +1,74 @@ +package manager + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/netip" + "slices" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/domain" +) + +type Set struct { + hash [4]byte + comment string +} + +// String returns the string representation of the set: hashed name and comment +func (h Set) String() string { + if h.comment == "" { + return h.HashedName() + } + return h.HashedName() + ": " + h.comment +} + +// HashedName returns the string representation of the hash +func (h Set) HashedName() string { + return fmt.Sprintf( + "nb-%s", + hex.EncodeToString(h.hash[:]), + ) +} + +// Comment returns the comment of the set +func (h Set) Comment() string { + return h.comment +} + +// NewPrefixSet generates a unique name for an ipset based on the given prefixes. +func NewPrefixSet(prefixes []netip.Prefix) Set { + // sort for consistent naming + SortPrefixes(prefixes) + + hash := sha256.New() + for _, src := range prefixes { + bytes, err := src.MarshalBinary() + if err != nil { + log.Warnf("failed to marshal prefix %s: %v", src, err) + } + hash.Write(bytes) + } + var set Set + copy(set.hash[:], hash.Sum(nil)[:4]) + + return set +} + +// NewDomainSet generates a unique name for an ipset based on the given domains. +func NewDomainSet(domains domain.List) Set { + slices.Sort(domains) + + hash := sha256.New() + for _, d := range domains { + hash.Write([]byte(d.PunycodeString())) + } + set := Set{ + comment: domains.SafeString(), + } + copy(set.hash[:], hash.Sum(nil)[:4]) + + return set +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a5809471c..e6b3a031b 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, + sPort, dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - if !destination.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -242,7 +241,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { return firewall.SetLegacyManagement(m.router, isLegacy) } -// Reset firewall to the default state +// Close closes the firewall manager func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -359,6 +358,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.router.DeleteDNATRule(rule) } +// UpdateSet updates the set with the given prefixes +func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.UpdateSet(set, prefixes) +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 373743a08..602a6b8dc 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -289,7 +289,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { _, err = manager.AddRouteFiltering( nil, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, - netip.MustParsePrefix("10.1.0.0/24"), + fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{443}}, @@ -298,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { require.NoError(t, err, "failed to add route filtering rule") pair := fw.RouterPair{ - Source: netip.MustParsePrefix("192.168.1.0/24"), - Destination: netip.MustParsePrefix("10.0.0.0/24"), + Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")}, Masquerade: true, } err = manager.AddNatRule(pair) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index aff86dd90..0f6c5bdf6 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/coreos/go-iptables/iptables" - "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -44,9 +43,14 @@ const ( const refreshRulesMapError = "refresh rules map: %w" var ( - errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") + errFilterTableNotFound = fmt.Errorf("'filter' table not found") ) +type setInput struct { + set firewall.Set + prefixes []netip.Prefix +} + type router struct { conn *nftables.Conn workTable *nftables.Table @@ -54,7 +58,7 @@ type router struct { chains map[string]*nftables.Chain // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules rules map[string]*nftables.Rule - ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] + ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState @@ -163,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error { func (r *router) loadFilterTable() (*nftables.Table, error) { tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { - return nil, fmt.Errorf("nftables: unable to list tables: %v", err) + return nil, fmt.Errorf("unable to list tables: %v", err) } for _, table := range tables { @@ -316,7 +320,7 @@ func (r *router) setupDataPlaneMark() error { func (r *router) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, @@ -331,23 +335,29 @@ func (r *router) AddRouteFiltering( chain := r.chains[chainNameRoutingFw] var exprs []expr.Any + var source firewall.Network switch { case len(sources) == 1 && sources[0].Bits() == 0: // If it's 0.0.0.0/0, we don't need to add any source matching case len(sources) == 1: // If there's only one source, we can use it directly - exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) + source.Prefix = sources[0] default: - // If there are multiple sources, create or get an ipset - var err error - exprs, err = r.getIpSetExprs(sources, exprs) - if err != nil { - return nil, fmt.Errorf("get ipset expressions: %w", err) - } + // If there are multiple sources, use a set + source.Set = firewall.NewPrefixSet(sources) } - // Handle destination - exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) + sourceExp, err := r.applyNetwork(source, sources, true) + if err != nil { + return nil, fmt.Errorf("apply source: %w", err) + } + exprs = append(exprs, sourceExp...) + + destExp, err := r.applyNetwork(destination, nil, false) + if err != nil { + return nil, fmt.Errorf("apply destination: %w", err) + } + exprs = append(exprs, destExp...) // Handle protocol if proto != firewall.ProtocolALL { @@ -391,39 +401,27 @@ func (r *router) AddRouteFiltering( rule = r.conn.AddRule(rule) } - log.Tracef("Adding route rule %s", spew.Sdump(rule)) if err := r.conn.Flush(); err != nil { return nil, fmt.Errorf(flushError, err) } r.rules[string(ruleKey)] = rule - log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) + log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) return ruleKey, nil } -func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { - setName := firewall.GenerateSetName(sources) - ref, err := r.ipsetCounter.Increment(setName, sources) +func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) { + ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{ + set: set, + prefixes: prefixes, + }) if err != nil { - return nil, fmt.Errorf("create or get ipset for sources: %w", err) + return nil, fmt.Errorf("create or get ipset: %w", err) } - exprs = append(exprs, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Lookup{ - SourceRegister: 1, - SetName: ref.Out.Name, - SetID: ref.Out.ID, - }, - ) - return exprs, nil + return getIpSetExprs(ref, isSource) } func (r *router) DeleteRouteRule(rule firewall.Rule) error { @@ -442,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return fmt.Errorf("route rule %s has no handle", ruleKey) } - setName := r.findSetNameInRule(nftRule) - if err := r.deleteNftRule(nftRule, ruleKey); err != nil { return fmt.Errorf("delete: %w", err) } - if setName != "" { - if _, err := r.ipsetCounter.Decrement(setName); err != nil { - return fmt.Errorf("decrement ipset reference: %w", err) - } - } - if err := r.conn.Flush(); err != nil { return fmt.Errorf(flushError, err) } + if err := r.decrementSetCounter(nftRule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) + } + return nil } -func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { +func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) { // overlapping prefixes will result in an error, so we need to merge them - sources = firewall.MergeIPRanges(sources) + prefixes := firewall.MergeIPRanges(input.prefixes) - set := &nftables.Set{ - Name: setName, - Table: r.workTable, + nfset := &nftables.Set{ + Name: setName, + Comment: input.set.Comment(), + Table: r.workTable, // required for prefixes Interval: true, KeyType: nftables.TypeIPAddr, } + elements := convertPrefixesToSet(prefixes) + if err := r.conn.AddSet(nfset, elements); err != nil { + return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) + } + + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf("flush error: %w", err) + } + + log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) + + return nfset, nil +} + +func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { var elements []nftables.SetElement - for _, prefix := range sources { + for _, prefix := range prefixes { // TODO: Implement IPv6 support if prefix.Addr().Is6() { - log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) continue } @@ -493,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables. nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, ) } - - if err := r.conn.AddSet(set, elements); err != nil { - return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) - } - - if err := r.conn.Flush(); err != nil { - return nil, fmt.Errorf("flush error: %w", err) - } - - log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) - - return set, nil + return elements } // calculateLastIP determines the last IP in a given prefix. @@ -528,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte { return b } -func (r *router) deleteIpSet(setName string, set *nftables.Set) error { - r.conn.DelSet(set) +func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error { + r.conn.DelSet(nfset) if err := r.conn.Flush(); err != nil { return fmt.Errorf(flushError, err) } @@ -538,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error { return nil } -func (r *router) findSetNameInRule(rule *nftables.Rule) string { - for _, e := range rule.Exprs { - if lookup, ok := e.(*expr.Lookup); ok { - return lookup.SetName +func (r *router) decrementSetCounter(rule *nftables.Rule) error { + sets := r.findSets(rule) + + var merr *multierror.Error + for _, setName := range sets { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err)) } } - return "" + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) findSets(rule *nftables.Rule) []string { + var sets []string + for _, e := range rule.Exprs { + if lookup, ok := e.(*expr.Lookup); ok { + sets = append(sets, lookup.SetName) + } + } + return sets } func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { @@ -586,7 +599,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { } if err := r.conn.Flush(); err != nil { - return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) + // TODO: rollback ipset counter + return fmt.Errorf("insert rules for %s: %v", pair.Destination, err) } return nil @@ -594,19 +608,22 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { // addNatRule inserts a nftables rule to the conn client flush queue func (r *router) addNatRule(pair firewall.RouterPair) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) + sourceExp, err := r.applyNetwork(pair.Source, nil, true) + if err != nil { + return fmt.Errorf("apply source: %w", err) + } + + destExp, err := r.applyNetwork(pair.Destination, nil, false) + if err != nil { + return fmt.Errorf("apply destination: %w", err) + } op := expr.CmpOpEq if pair.Inverse { op = expr.CmpOpNeq } - // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. - // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. - exprs := getCtNewExprs() - exprs = append(exprs, - // interface matching + exprs := []expr.Any{ &expr.Meta{ Key: expr.MetaKeyIIFNAME, Register: 1, @@ -616,7 +633,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { Register: 1, Data: ifname(r.wgIface.Name()), }, - ) + } + // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. + // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. + exprs = append(exprs, getCtNewExprs()...) exprs = append(exprs, sourceExp...) exprs = append(exprs, destExp...) @@ -646,7 +666,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { } } - r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + // Ensure nat rules come first, so the mark can be overwritten. + // Currently overwritten by the dst-type LOCAL rules for redirected traffic. + r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{ Table: r.workTable, Chain: r.chains[chainNameManglePrerouting], Exprs: exprs, @@ -729,8 +751,15 @@ func (r *router) addPostroutingRules() error { // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) + sourceExp, err := r.applyNetwork(pair.Source, nil, true) + if err != nil { + return fmt.Errorf("apply source: %w", err) + } + + destExp, err := r.applyNetwork(pair.Destination, nil, false) + if err != nil { + return fmt.Errorf("apply destination: %w", err) + } exprs := []expr.Any{ &expr.Counter{}, @@ -739,7 +768,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { }, } - expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic + exprs = append(exprs, sourceExp...) + exprs = append(exprs, destExp...) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) @@ -752,7 +782,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ Table: r.workTable, Chain: r.chains[chainNameRoutingFw], - Exprs: expression, + Exprs: exprs, UserData: []byte(ruleKey), }) return nil @@ -767,11 +797,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } - log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) + log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) delete(r.rules, ruleKey) - } else { - log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) + } } return nil @@ -982,12 +1014,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf(refreshRulesMapError, err) } - if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove prerouting rule: %w", err) - } + if pair.Masquerade { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove prerouting rule: %w", err) + } - if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { - return fmt.Errorf("remove inverse prerouting rule: %w", err) + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse prerouting rule: %w", err) + } } if err := r.removeLegacyRouteRule(pair); err != nil { @@ -995,10 +1029,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { } if err := r.conn.Flush(); err != nil { - return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) + // TODO: rollback set counter + return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err) } - log.Debugf("nftables: removed nat rules for %s", pair.Destination) return nil } @@ -1006,16 +1040,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) if rule, exists := r.rules[ruleKey]; exists { - err := r.conn.DelRule(rule) - if err != nil { + if err := r.conn.DelRule(rule); err != nil { return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) } - log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination) + log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination) delete(r.rules, ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) + } } else { - log.Debugf("nftables: prerouting rule %s not found", ruleKey) + log.Debugf("prerouting rule %s not found", ruleKey) } return nil @@ -1027,7 +1064,7 @@ func (r *router) refreshRulesMap() error { for _, chain := range r.chains { rules, err := r.conn.GetRules(chain.Table, chain) if err != nil { - return fmt.Errorf("nftables: unable to list rules: %v", err) + return fmt.Errorf(" unable to list rules: %v", err) } for _, rule := range rules { if len(rule.UserData) > 0 { @@ -1301,13 +1338,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { return nberrors.FormatErrorOrNil(merr) } -// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR -func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { - var offset uint32 - if source { - offset = 12 // src offset - } else { - offset = 16 // dst offset +func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName()) + if err != nil { + return fmt.Errorf("get set %s: %w", set.HashedName(), err) + } + + elements := convertPrefixesToSet(prefixes) + if err := r.conn.SetAddElements(nfset, elements); err != nil { + return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + + return nil +} + +// applyNetwork generates nftables expressions for networks (CIDR) or sets +func (r *router) applyNetwork( + network firewall.Network, + setPrefixes []netip.Prefix, + isSource bool, +) ([]expr.Any, error) { + if network.IsSet() { + exprs, err := r.getIpSet(network.Set, setPrefixes, isSource) + if err != nil { + return nil, fmt.Errorf("source: %w", err) + } + return exprs, nil + } + + if network.IsPrefix() { + return applyPrefix(network.Prefix, isSource), nil + } + + return nil, nil +} + +// applyPrefix generates nftables expressions for a CIDR prefix +func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { + // dst offset + offset := uint32(16) + if isSource { + // src offset + offset = 12 } ones := prefix.Bits() @@ -1415,3 +1493,27 @@ func getCtNewExprs() []expr.Any { }, } } + +func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { + + // dst offset + offset := uint32(16) + if isSource { + // src offset + offset = 12 + } + + return []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offset, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: ref.Out.Name, + SetID: ref.Out.ID, + }, + }, nil +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 28baef4dd..4fdbf3505 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) { } // Build CIDR matching expressions - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) + destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) // Combine all expressions in the correct order // nolint:gocritic @@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") t.Cleanup(func() { @@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - setName := firewall.GenerateSetName(tt.sources) - set, err := r.createIpSet(setName, tt.sources) + setName := firewall.NewPrefixSet(tt.sources).HashedName() + set, err := r.createIpSet(setName, setInput{prefixes: tt.sources}) if err != nil { t.Logf("Failed to create IP set: %v", err) printNftSets() diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go index 267e93efd..59a370a97 100644 --- a/client/firewall/test/cases_linux.go +++ b/client/firewall/test/cases_linux.go @@ -15,8 +15,8 @@ var ( Name: "Insert Forwarding IPV4 Rule", InputPair: firewall.RouterPair{ ID: "zxa", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.200.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Masquerade: false, }, }, @@ -24,8 +24,8 @@ var ( Name: "Insert Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.200.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Masquerade: true, }, }, @@ -40,8 +40,8 @@ var ( Name: "Remove Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.200.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Masquerade: true, }, }, diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 5fe698aa9..ce04c82c7 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -12,7 +12,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -// Reset firewall to the default state +// Close cleans up the firewall manager by removing all rules and closing trackers func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index f63792fec..f261c472f 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -10,7 +10,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -22,7 +21,7 @@ const ( firewallRuleName = "Netbird" ) -// Reset firewall to the default state +// Close cleans up the firewall manager by removing all rules and closing trackers func (m *Manager) Close(*statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) } if fwder := m.forwarder.Load(); fwder != nil { diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 0dff3acc7..2ae983f6e 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net" + "net/netip" "runtime" + "sync" log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/buffer" @@ -17,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) @@ -29,8 +32,10 @@ const ( ) type Forwarder struct { - logger *nblog.Logger - flowLogger nftypes.FlowLogger + logger *nblog.Logger + flowLogger nftypes.FlowLogger + // ruleIdMap is used to store the rule ID for a given connection + ruleIdMap sync.Map stack *stack.Stack endpoint *endpoint udpForwarder *udpForwarder @@ -167,3 +172,35 @@ func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { } return addr.AsSlice() } + +func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { + key := buildKey(srcIP, dstIP, srcPort, dstPort) + f.ruleIdMap.LoadOrStore(key, ruleID) +} + +func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { + + if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { + return value.([]byte), true + } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { + return value.([]byte), true + } + + return nil, false +} + +func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { + if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { + return + } + f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort)) +} + +func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey { + return conntrack.ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index a21ec2c87..08d77ed05 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf } flowID := uuid.New() - f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode) + f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) defer cancel() @@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf // TODO: support non-root conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") if err != nil { - f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err) + f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err) // This will make netstack reply on behalf of the original destination, that's ok for now return false } defer func() { if err := conn.Close(); err != nil { - f.logger.Debug("Failed to close ICMP socket: %v", err) + f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err) } }() @@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf payload := fullPacket.AsSlice() if _, err = conn.WriteTo(payload, dst); err != nil { - f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err) + f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err) return true } - f.logger.Trace("Forwarded ICMP packet %v type %v code %v", + f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) // For Echo Requests, send and handle response if header.ICMPv4Type(icmpType) == header.ICMPv4Echo { - f.handleEchoResponse(icmpHdr, conn, id) - f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode) + rxBytes := pkt.Size() + txBytes := f.handleEchoResponse(icmpHdr, conn, id) + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing return true } -func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) { +func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - f.logger.Error("Failed to set read deadline for ICMP response: %v", err) - return + f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err) + return 0 } response := make([]byte, f.endpoint.mtu) n, _, err := conn.ReadFrom(response) if err != nil { if !isTimeout(err) { - f.logger.Error("Failed to read ICMP response: %v", err) + f.logger.Error("forwarder: Failed to read ICMP response: %v", err) } - return + return 0 } ipHdr := make([]byte, header.IPv4MinimumSize) @@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon fullPacket = append(fullPacket, response[:n]...) if err := f.InjectIncomingPacket(fullPacket); err != nil { - f.logger.Error("Failed to inject ICMP response: %v", err) + f.logger.Error("forwarder: Failed to inject ICMP response: %v", err) - return + return 0 } - f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v", + f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) + + return len(fullPacket) } // sendICMPEvent stores flow events for ICMP packets -func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) { - f.flowLogger.StoreEvent(nftypes.EventFields{ +func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) { + var rxPackets, txPackets uint64 + if rxBytes > 0 { + rxPackets = 1 + } + if txBytes > 0 { + txPackets = 1 + } + + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.ICMP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, ICMPType: icmpType, ICMPCode: icmpCode, - // TODO: get packets/bytes - }) + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, + } + + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId + } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) + } + + f.flowLogger.StoreEvent(fields) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 71cd457ef..04b3ae233 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -6,8 +6,10 @@ import ( "io" "net" "net/netip" + "sync" "github.com/google/uuid" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { flowID := uuid.New() - f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil) + f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) var success bool defer func() { if !success { - f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil) + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) } }() @@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { } func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { - defer func() { - if err := inConn.Close(); err != nil { - f.logger.Debug("forwarder: inConn close error: %v", err) - } - if err := outConn.Close(); err != nil { - f.logger.Debug("forwarder: outConn close error: %v", err) - } - ep.Close() - f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep) - }() - - // Create context for managing the proxy goroutines ctx, cancel := context.WithCancel(f.ctx) defer cancel() - errChan := make(chan error, 2) - go func() { - _, err := io.Copy(outConn, inConn) - errChan <- err - }() - - go func() { - _, err := io.Copy(inConn, outConn) - errChan <- err - }() - - select { - case <-ctx.Done(): - f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id)) - return - case err := <-errChan: - if err != nil && !isClosedError(err) { - f.logger.Error("proxyTCP: copy error: %v", err) + <-ctx.Done() + // Close connections and endpoint. + if err := inConn.Close(); err != nil && !isClosedError(err) { + f.logger.Debug("forwarder: inConn close error: %v", err) + } + if err := outConn.Close(); err != nil && !isClosedError(err) { + f.logger.Debug("forwarder: outConn close error: %v", err) + } + + ep.Close() + }() + + var wg sync.WaitGroup + wg.Add(2) + + var ( + bytesFromInToOut int64 // bytes from client to server (tx for client) + bytesFromOutToIn int64 // bytes from server to client (rx for client) + errInToOut error + errOutToIn error + ) + + go func() { + bytesFromInToOut, errInToOut = io.Copy(outConn, inConn) + cancel() + wg.Done() + }() + + go func() { + + bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn) + cancel() + wg.Done() + }() + + wg.Wait() + + if errInToOut != nil { + if !isClosedError(errInToOut) { + f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut) } - f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id)) - return } + if errOutToIn != nil { + if !isClosedError(errOutToIn) { + f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn) + } + } + + var rxPackets, txPackets uint64 + if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { + // fields are flipped since this is the in conn + rxPackets = tcpStats.SegmentsSent.Value() + txPackets = tcpStats.SegmentsReceived.Value() + } + + f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) } -func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { +func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.TCP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, SourcePort: id.RemotePort, DestPort: id.LocalPort, + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, } - if ep != nil { - if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { - // fields are flipped since this is the in conn - // TODO: get bytes - fields.RxPackets = tcpStats.SegmentsSent.Value() - fields.TxPackets = tcpStats.SegmentsReceived.Value() + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) } f.flowLogger.StoreEvent(fields) diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 7ce85e2b6..cb88aa59a 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { flowID := uuid.New() - f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil) + f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) var success bool defer func() { if !success { - f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil) + f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) } }() @@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err := outConn.Close(); err != nil { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } - return } f.udpForwarder.conns[id] = pConn @@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { } func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { - defer func() { + + ctx, cancel := context.WithCancel(f.ctx) + defer cancel() + + go func() { + <-ctx.Done() + pConn.cancel() - if err := pConn.conn.Close(); err != nil { + if err := pConn.conn.Close(); err != nil && !isClosedError(err) { f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err) } - if err := pConn.outConn.Close(); err != nil { + if err := pConn.outConn.Close(); err != nil && !isClosedError(err) { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } ep.Close() - - f.udpForwarder.Lock() - delete(f.udpForwarder.conns, id) - f.udpForwarder.Unlock() - - f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep) }() - errChan := make(chan error, 2) + var wg sync.WaitGroup + wg.Add(2) + var txBytes, rxBytes int64 + var outboundErr, inboundErr error + + // outbound->inbound: copy from pConn.conn to pConn.outConn go func() { - errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") + defer wg.Done() + txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") }() + // inbound->outbound: copy from pConn.outConn to pConn.conn go func() { - errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") + defer wg.Done() + rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") }() - select { - case <-ctx.Done(): - f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id)) - return - case err := <-errChan: - if err != nil && !isClosedError(err) { - f.logger.Error("proxyUDP: copy error: %v", err) - } - f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id)) - return + wg.Wait() + + if outboundErr != nil && !isClosedError(outboundErr) { + f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr) } + if inboundErr != nil && !isClosedError(inboundErr) { + f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr) + } + + var rxPackets, txPackets uint64 + if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { + // fields are flipped since this is the in conn + rxPackets = udpStats.PacketsSent.Value() + txPackets = udpStats.PacketsReceived.Value() + } + + f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + + f.udpForwarder.Lock() + delete(f.udpForwarder.conns, id) + f.udpForwarder.Unlock() + + f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets) } // sendUDPEvent stores flow events for UDP connections -func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { +func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.UDP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, SourcePort: id.RemotePort, DestPort: id.LocalPort, + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, } - if ep != nil { - if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { - // fields are flipped since this is the in conn - // TODO: get bytes - fields.RxPackets = tcpStats.PacketsSent.Value() - fields.TxPackets = tcpStats.PacketsReceived.Value() + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) } f.flowLogger.StoreEvent(fields) @@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration { return time.Since(lastSeen) } -func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error { +// copy reads from src and writes to dst. +func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) { bufp := bufPool.Get().(*[]byte) defer bufPool.Put(bufp) buffer := *bufp + var totalBytes int64 = 0 for { if ctx.Err() != nil { - return ctx.Err() + return totalBytes, ctx.Err() } if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil { - return fmt.Errorf("set read deadline: %w", err) + return totalBytes, fmt.Errorf("set read deadline: %w", err) } n, err := src.Read(buffer) @@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu if isTimeout(err) { continue } - return fmt.Errorf("read from %s: %w", direction, err) + return totalBytes, fmt.Errorf("read from %s: %w", direction, err) } - _, err = dst.Write(buffer[:n]) + nWritten, err := dst.Write(buffer[:n]) if err != nil { - return fmt.Errorf("write to %s: %w", direction, err) + return totalBytes, fmt.Errorf("write to %s: %w", direction, err) } + totalBytes += int64(nWritten) c.updateLastSeen() } } diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index a23d2011b..b765c72e9 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -29,14 +29,15 @@ func (r *PeerRule) ID() string { } type RouteRule struct { - id string - mgmtId []byte - sources []netip.Prefix - destination netip.Prefix - proto firewall.Protocol - srcPort *firewall.Port - dstPort *firewall.Port - action firewall.Action + id string + mgmtId []byte + sources []netip.Prefix + dstSet firewall.Set + destinations []netip.Prefix + proto firewall.Protocol + srcPort *firewall.Port + dstPort *firewall.Port + action firewall.Action } // ID returns the rule id diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 48b0ec44d..bd87879a5 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -198,12 +198,12 @@ func TestTracePacket(t *testing.T) { m.forwarder.Store(&forwarder.Forwarder{}) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) - dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) - _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) + dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32) + _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -222,12 +222,12 @@ func TestTracePacket(t *testing.T) { m.nativeRouter.Store(false) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) - dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) - _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) + dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32) + _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -245,7 +245,7 @@ func TestTracePacket(t *testing.T) { m.nativeRouter.Store(true) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -263,7 +263,7 @@ func TestTracePacket(t *testing.T) { m.routingEnabled.Store(false) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -425,8 +425,8 @@ func TestTracePacket(t *testing.T) { require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")), "100.10.0.100 should be recognized as a local IP") - require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")), - "172.17.0.2 should not be recognized as a local IP") + require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")), + "192.168.17.2 should not be recognized as a local IP") pb := tc.packetBuilder() diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 466c6a18b..11730dbb3 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -49,10 +49,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall") // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule -type RouteRules []RouteRule +type RouteRules []*RouteRule func (r RouteRules) Sort() { - slices.SortStableFunc(r, func(a, b RouteRule) int { + slices.SortStableFunc(r, func(a, b *RouteRule) int { // Deny rules come first if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop { return -1 @@ -99,6 +99,8 @@ type Manager struct { forwarder atomic.Pointer[forwarder.Forwarder] logger *nblog.Logger flowLogger nftypes.FlowLogger + + blockRule firewall.Rule } // decoder for packages @@ -201,41 +203,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe } } - if err := m.blockInvalidRouted(iface); err != nil { - log.Errorf("failed to block invalid routed traffic: %v", err) - } - if err := iface.SetFilter(m); err != nil { return nil, fmt.Errorf("set filter: %w", err) } return m, nil } -func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { - if m.forwarder.Load() == nil { - return nil - } +func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) if err != nil { - return fmt.Errorf("parse wireguard network: %w", err) + return nil, fmt.Errorf("parse wireguard network: %w", err) } log.Debugf("blocking invalid routed traffic for %s", wgPrefix) - if _, err := m.AddRouteFiltering( + rule, err := m.addRouteFiltering( nil, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, - wgPrefix, + firewall.Network{Prefix: wgPrefix}, firewall.ProtocolALL, nil, nil, firewall.ActionDrop, - ); err != nil { - return fmt.Errorf("block wg nte : %w", err) + ) + if err != nil { + return nil, fmt.Errorf("block wg nte : %w", err) } // TODO: Block networks that we're a client of - return nil + return rule, nil } func (m *Manager) determineRouting() error { @@ -413,10 +409,23 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, + sPort, dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action) +} + +func (m *Manager) addRouteFiltering( + id []byte, + sources []netip.Prefix, + destination firewall.Network, + proto firewall.Protocol, + sPort, dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { if m.nativeRouter.Load() && m.nativeFirewall != nil { @@ -426,34 +435,39 @@ func (m *Manager) AddRouteFiltering( ruleID := uuid.New().String() rule := RouteRule{ // TODO: consolidate these IDs - id: ruleID, - mgmtId: id, - sources: sources, - destination: destination, - proto: proto, - srcPort: sPort, - dstPort: dPort, - action: action, + id: ruleID, + mgmtId: id, + sources: sources, + dstSet: destination.Set, + proto: proto, + srcPort: sPort, + dstPort: dPort, + action: action, + } + if destination.IsPrefix() { + rule.destinations = []netip.Prefix{destination.Prefix} } - m.mutex.Lock() - m.routeRules = append(m.routeRules, rule) + m.routeRules = append(m.routeRules, &rule) m.routeRules.Sort() - m.mutex.Unlock() return &rule, nil } func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.deleteRouteRule(rule) +} + +func (m *Manager) deleteRouteRule(rule firewall.Rule) error { if m.nativeRouter.Load() && m.nativeFirewall != nil { return m.nativeFirewall.DeleteRouteRule(rule) } - m.mutex.Lock() - defer m.mutex.Unlock() - ruleID := rule.ID() - idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { + idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool { return r.id == ruleID }) if idx < 0 { @@ -509,6 +523,52 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.nativeFirewall.DeleteDNATRule(rule) } +// UpdateSet updates the rule destinations associated with the given set +// by merging the existing prefixes with the new ones, then deduplicating. +func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + if m.nativeRouter.Load() && m.nativeFirewall != nil { + return m.nativeFirewall.UpdateSet(set, prefixes) + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + var matches []*RouteRule + for _, rule := range m.routeRules { + if rule.dstSet == set { + matches = append(matches, rule) + } + } + + if len(matches) == 0 { + return fmt.Errorf("no route rule found for set: %s", set) + } + + destinations := matches[0].destinations + for _, prefix := range prefixes { + if prefix.Addr().Is4() { + destinations = append(destinations, prefix) + } + } + + slices.SortFunc(destinations, func(a, b netip.Prefix) int { + cmp := a.Addr().Compare(b.Addr()) + if cmp != 0 { + return cmp + } + return a.Bits() - b.Bits() + }) + + destinations = slices.Compact(destinations) + + for _, rule := range matches { + rule.destinations = destinations + } + log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations) + + return nil +} + // DropOutgoing filter outgoing packets func (m *Manager) DropOutgoing(packetData []byte, size int) bool { return m.processOutgoingHooks(packetData, size) @@ -764,7 +824,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe proto, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { + ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + if !pass { m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", ruleID, pnum, srcIP, srcPort, dstIP, dstPort) @@ -790,8 +851,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe if fwd == nil { m.logger.Trace("failed to forward routed packet (forwarder not initialized)") } else { + fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID) + if err := fwd.InjectIncomingPacket(packetData); err != nil { m.logger.Error("Failed to inject routed packet: %v", err) + fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort) } } @@ -988,8 +1052,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol return nil, false } -func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { - if !rule.destination.Contains(dstAddr) { +func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { + destMatched := false + for _, dst := range rule.destinations { + if dst.Contains(dstAddr) { + destMatched = true + break + } + } + if !destMatched { return false } @@ -1091,7 +1162,22 @@ func (m *Manager) EnableRouting() error { m.mutex.Lock() defer m.mutex.Unlock() - return m.determineRouting() + if err := m.determineRouting(); err != nil { + return fmt.Errorf("determine routing: %w", err) + } + + if m.forwarder.Load() == nil { + return nil + } + + rule, err := m.blockInvalidRouted(m.wgIface) + if err != nil { + return fmt.Errorf("block invalid routed: %w", err) + } + + m.blockRule = rule + + return nil } func (m *Manager) DisableRouting() error { @@ -1116,5 +1202,12 @@ func (m *Manager) DisableRouting() error { log.Debug("forwarder stopped") + if m.blockRule != nil { + if err := m.deleteRouteRule(m.blockRule); err != nil { + return fmt.Errorf("delete block rule: %w", err) + } + m.blockRule = nil + } + return nil } diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index ba97c2643..04a398d1f 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/management/domain" ) func TestPeerACLFiltering(t *testing.T) { @@ -188,6 +189,281 @@ func TestPeerACLFiltering(t *testing.T) { ruleAction: fw.ActionAccept, shouldBeBlocked: true, }, + { + name: "Allow TCP traffic without port specification", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow UDP traffic without port specification", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "TCP packet doesn't match UDP filter with same port", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "UDP packet doesn't match TCP filter with same port", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "ICMP packet doesn't match TCP filter", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolICMP, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "ICMP packet doesn't match UDP filter", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolICMP, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "Allow TCP traffic within port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Block TCP traffic outside port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "Edge Case - Port at Range Boundary", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8100, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "UDP Port Range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 5060, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow multiple destination ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow multiple source ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + // New drop test cases + { + name: "Drop TCP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop UDP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{Values: []uint16{53}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop ICMP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolICMP, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop all traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop traffic from multiple source ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop multiple destination ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop TCP traffic within port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Accept TCP traffic outside drop port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: false, + }, + { + name: "Drop TCP traffic with source port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 32100, + dstPort: 80, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Mixed rule - drop specific port but allow other ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, } t.Run("Implicit DROP (no rules)", func(t *testing.T) { @@ -198,6 +474,28 @@ func TestPeerACLFiltering(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + + if tc.ruleAction == fw.ActionDrop { + // add general accept rule to test drop rule + // TODO: this only works because 0.0.0.0 is tested last, we need to implement order + rules, err := manager.AddPeerFiltering( + nil, + net.ParseIP("0.0.0.0"), + fw.ProtocolALL, + nil, + nil, + fw.ActionAccept, + "", + ) + require.NoError(t, err) + require.NotEmpty(t, rules) + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + } + rules, err := manager.AddPeerFiltering( nil, net.ParseIP(tc.ruleIP), @@ -303,8 +601,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { } manager, err := Create(ifaceMock, false, flowLogger) - require.NoError(tb, manager.EnableRouting()) require.NoError(tb, err) + require.NoError(tb, manager.EnableRouting()) require.NotNil(tb, manager) require.True(tb, manager.routingEnabled.Load()) require.False(tb, manager.nativeRouter.Load()) @@ -321,7 +619,7 @@ func TestRouteACLFiltering(t *testing.T) { type rule struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -347,7 +645,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionAccept, @@ -363,7 +661,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionAccept, @@ -379,7 +677,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - dest: netip.MustParsePrefix("0.0.0.0/0"), + dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionAccept, @@ -395,7 +693,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 53, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, dstPort: &fw.Port{Values: []uint16{53}}, action: fw.ActionAccept, @@ -409,7 +707,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("0.0.0.0/0"), + dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, proto: fw.ProtocolICMP, action: fw.ActionAccept, }, @@ -424,7 +722,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -440,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -456,7 +754,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -472,7 +770,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -488,7 +786,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{Values: []uint16{12345}}, action: fw.ActionAccept, @@ -507,7 +805,7 @@ func TestRouteACLFiltering(t *testing.T) { netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"), }, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -521,7 +819,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, action: fw.ActionAccept, }, @@ -536,33 +834,13 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, }, shouldPass: true, }, - { - name: "Multiple source networks with mismatched protocol", - srcIP: "172.16.0.1", - dstIP: "192.168.1.100", - // Should not match TCP rule - proto: fw.ProtocolUDP, - srcPort: 12345, - dstPort: 80, - rule: rule{ - sources: []netip.Prefix{ - netip.MustParsePrefix("100.10.0.0/16"), - netip.MustParsePrefix("172.16.0.0/16"), - }, - dest: netip.MustParsePrefix("192.168.1.0/24"), - proto: fw.ProtocolTCP, - dstPort: &fw.Port{Values: []uint16{80}}, - action: fw.ActionAccept, - }, - shouldPass: false, - }, { name: "Allow multiple destination ports", srcIP: "100.10.0.1", @@ -572,7 +850,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, action: fw.ActionAccept, @@ -588,7 +866,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, action: fw.ActionAccept, @@ -604,7 +882,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, srcPort: &fw.Port{Values: []uint16{12345}}, dstPort: &fw.Port{Values: []uint16{80}}, @@ -621,7 +899,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{ IsRange: true, @@ -640,7 +918,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 7999, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{ IsRange: true, @@ -659,7 +937,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{ IsRange: true, @@ -678,7 +956,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{ IsRange: true, @@ -700,7 +978,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8100, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{ IsRange: true, @@ -719,7 +997,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 5060, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, dstPort: &fw.Port{ IsRange: true, @@ -738,7 +1016,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, dstPort: &fw.Port{ IsRange: true, @@ -757,7 +1035,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionDrop, @@ -773,7 +1051,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, action: fw.ActionDrop, }, @@ -791,17 +1069,158 @@ func TestRouteACLFiltering(t *testing.T) { netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"), }, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionDrop, }, shouldPass: false, }, + + { + name: "Drop empty destination set", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: fw.Network{Set: fw.Set{}}, + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Accept TCP traffic outside drop port range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + proto: fw.ProtocolTCP, + dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + action: fw.ActionDrop, + }, + shouldPass: true, + }, + { + name: "Allow TCP traffic without port specification", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + proto: fw.ProtocolTCP, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow UDP traffic without port specification", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + proto: fw.ProtocolUDP, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "TCP packet doesn't match UDP filter with same port", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + proto: fw.ProtocolUDP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "UDP packet doesn't match TCP filter with same port", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "ICMP packet doesn't match TCP filter", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolICMP, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + proto: fw.ProtocolTCP, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "ICMP packet doesn't match UDP filter", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolICMP, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + proto: fw.ProtocolUDP, + action: fw.ActionAccept, + }, + shouldPass: false, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + if tc.rule.action == fw.ActionDrop { + // add general accept rule to test drop rule + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, + fw.ProtocolALL, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + t.Cleanup(func() { + require.NoError(t, manager.DeleteRouteRule(rule)) + }) + } + rule, err := manager.AddRouteFiltering( nil, tc.rule.sources, @@ -836,7 +1255,7 @@ func TestRouteACLOrder(t *testing.T) { name string rules []struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -857,7 +1276,7 @@ func TestRouteACLOrder(t *testing.T) { name: "Drop rules take precedence over accept", rules: []struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -866,7 +1285,7 @@ func TestRouteACLOrder(t *testing.T) { { // Accept rule added first sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80, 443}}, action: fw.ActionAccept, @@ -874,7 +1293,7 @@ func TestRouteACLOrder(t *testing.T) { { // Drop rule added second but should be evaluated first sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionDrop, @@ -912,7 +1331,7 @@ func TestRouteACLOrder(t *testing.T) { name: "Multiple drop rules take precedence", rules: []struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -921,14 +1340,14 @@ func TestRouteACLOrder(t *testing.T) { { // Accept all sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - dest: netip.MustParsePrefix("0.0.0.0/0"), + dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, proto: fw.ProtocolALL, action: fw.ActionAccept, }, { // Drop specific port sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionDrop, @@ -936,7 +1355,7 @@ func TestRouteACLOrder(t *testing.T) { { // Drop different port sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionDrop, @@ -1015,3 +1434,53 @@ func TestRouteACLOrder(t *testing.T) { }) } } + +func TestRouteACLSet(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: net.ParseIP("100.10.0.100"), + Network: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + set := fw.NewDomainSet(domain.List{"example.org"}) + + // Add rule that uses the set (initially empty) + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Set: set}, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + srcIP := netip.MustParseAddr("100.10.0.1") + dstIP := netip.MustParseAddr("192.168.1.100") + + // Check that traffic is dropped (empty set shouldn't match anything) + _, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + require.False(t, isAllowed, "Empty set should not allow any traffic") + + err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}) + require.NoError(t, err) + + // Now the packet should be allowed + _, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index a48a483f8..24a6a2c40 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" + "github.com/netbirdio/netbird/management/domain" ) var logger = log.NewFromLogrus(logrus.StandardLogger()) @@ -711,3 +712,203 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { }) } } + +func TestUpdateSetMerge(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + + manager, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + set := fw.NewDomainSet(domain.List{"example.org"}) + + initialPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Set: set}, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + // Update the set with initial prefixes + err = manager.UpdateSet(set, initialPrefixes) + require.NoError(t, err) + + // Test initial prefixes work + srcIP := netip.MustParseAddr("100.10.0.1") + dstIP1 := netip.MustParseAddr("10.0.0.100") + dstIP2 := netip.MustParseAddr("192.168.1.100") + dstIP3 := netip.MustParseAddr("172.16.0.100") + + _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) + _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) + _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80) + + require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed") + require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed") + require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied") + + newPrefixes := []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("10.1.0.0/24"), + } + + err = manager.UpdateSet(set, newPrefixes) + require.NoError(t, err) + + // Check that all original prefixes are still included + _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) + _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) + require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update") + require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update") + + // Check that new prefixes are included + dstIP4 := netip.MustParseAddr("172.16.1.100") + dstIP5 := netip.MustParseAddr("10.1.0.50") + + _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80) + _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80) + + require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed") + require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed") + + // Verify the rule has all prefixes + manager.mutex.RLock() + foundRule := false + for _, r := range manager.routeRules { + if r.id == rule.ID() { + foundRule = true + require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes), + "Rule should have all prefixes merged") + } + } + manager.mutex.RUnlock() + require.True(t, foundRule, "Rule should be found") +} + +func TestUpdateSetDeduplication(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + + manager, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + set := fw.NewDomainSet(domain.List{"example.org"}) + + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Set: set}, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + initialPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), // Duplicate + } + + err = manager.UpdateSet(set, initialPrefixes) + require.NoError(t, err) + + // Check the internal state for deduplication + manager.mutex.RLock() + foundRule := false + for _, r := range manager.routeRules { + if r.id == rule.ID() { + foundRule = true + // Should have deduplicated to 2 prefixes + require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed") + + // Check the prefixes are correct + expectedPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + } + for i, prefix := range expectedPrefixes { + require.True(t, r.destinations[i] == prefix, + "Prefix should match expected value") + } + } + } + manager.mutex.RUnlock() + require.True(t, foundRule, "Rule should be found") + + // Test with overlapping prefixes of different sizes + overlappingPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/16"), // More general + netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists) + netip.MustParsePrefix("192.168.0.0/16"), // More general + netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists) + } + + err = manager.UpdateSet(set, overlappingPrefixes) + require.NoError(t, err) + + // Check that all prefixes are included (no deduplication of overlapping prefixes) + manager.mutex.RLock() + for _, r := range manager.routeRules { + if r.id == rule.ID() { + // Should have all 4 prefixes (2 original + 2 new more general ones) + require.Len(t, r.destinations, 4, + "Overlapping prefixes should not be deduplicated") + + // Verify they're sorted correctly (more specific prefixes should come first) + prefixes := make([]string, 0, len(r.destinations)) + for _, p := range r.destinations { + prefixes = append(prefixes, p.String()) + } + + // Check sorted order + require.Equal(t, []string{ + "10.0.0.0/16", + "10.0.0.0/24", + "192.168.0.0/16", + "192.168.1.0/24", + }, prefixes, "Prefixes should be sorted") + } + } + manager.mutex.RUnlock() + + // Test functionality with all prefixes + testCases := []struct { + dstIP netip.Addr + expected bool + desc string + }{ + {netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"}, + {netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"}, + {netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"}, + {netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"}, + {netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"}, + } + + srcIP := netip.MustParseAddr("100.10.0.1") + for _, tc := range testCases { + _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80) + require.Equal(t, tc.expected, isAllowed, tc.desc) + } +} diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index 93f16b429..23451453e 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -18,7 +18,7 @@ func (r RuleID) ID() string { func GenerateRouteRuleKey( sources []netip.Prefix, - destination netip.Prefix, + destination manager.Network, proto manager.Protocol, sPort *manager.Port, dPort *manager.Port, diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 61fbb10ca..6fa35d5c2 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -18,6 +18,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty") // Manager is a ACL rules manager type Manager interface { - ApplyFiltering(networkMap *mgmProto.NetworkMap) + ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } type protoMatch struct { @@ -53,7 +54,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager { // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // // If allowByDefault is true it appends allow ALL traffic rules to input and output chains. -func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { +func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) { d.mutex.Lock() defer d.mutex.Unlock() @@ -82,7 +83,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { log.Errorf("failed to set legacy management flag: %v", err) } - if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil { + if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil { log.Errorf("Failed to apply route ACLs: %v", err) } @@ -176,16 +177,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { d.peerRulesPairs = newRulePairs } -func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { +func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error { newRouteRules := make(map[id.RuleID]struct{}, len(rules)) var merr *multierror.Error // Apply new rules - firewall manager will return existing rule ID if already present for _, rule := range rules { - id, err := d.applyRouteACL(rule) + id, err := d.applyRouteACL(rule, dynamicResolver) if err != nil { if errors.Is(err, ErrSourceRangesEmpty) { - log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err) + log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err) } else { merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) } @@ -208,7 +209,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err return nberrors.FormatErrorOrNil(merr) } -func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { +func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) { if len(rule.SourceRanges) == 0 { return "", ErrSourceRangesEmpty } @@ -222,15 +223,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul sources = append(sources, source) } - var destination netip.Prefix - if rule.IsDynamic { - destination = getDefault(sources[0]) - } else { - var err error - destination, err = netip.ParsePrefix(rule.Destination) - if err != nil { - return "", fmt.Errorf("parse destination: %w", err) - } + destination, err := determineDestination(rule, dynamicResolver, sources) + if err != nil { + return "", fmt.Errorf("determine destination: %w", err) } protocol, err := convertToFirewallProtocol(rule.Protocol) @@ -580,6 +575,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port { return nil } +func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) { + var destination firewall.Network + + if rule.IsDynamic { + if dynamicResolver { + if len(rule.Domains) > 0 { + destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains)) + } else { + // isDynamic is set but no domains = outdated management server + log.Warn("connected to an older version of management server (no domains in rules), using default destination") + destination.Prefix = getDefault(sources[0]) + } + } else { + // client resolves DNS, we (router) don't know the destination + destination.Prefix = getDefault(sources[0]) + } + return destination, nil + } + + prefix, err := netip.ParsePrefix(rule.Destination) + if err != nil { + return destination, fmt.Errorf("parse destination: %w", err) + } + destination.Prefix = prefix + return destination, nil +} + func getDefault(prefix netip.Prefix) netip.Prefix { if prefix.Addr().Is6() { return netip.PrefixFrom(netip.IPv6Unspecified(), 0) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 9488d33ab..3595ca600 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -66,7 +66,7 @@ func TestDefaultManager(t *testing.T) { acl := NewDefaultManager(fw) t.Run("apply firewall rules", func(t *testing.T) { - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) if len(acl.peerRulesPairs) != 2 { t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) @@ -92,7 +92,7 @@ func TestDefaultManager(t *testing.T) { }, ) - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) // we should have one old and one new rule in the existed rules if len(acl.peerRulesPairs) != 2 { @@ -116,13 +116,13 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRulesIsEmpty = true - if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { + if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 { t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) return } networkMap.FirewallRulesIsEmpty = false - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) if len(acl.peerRulesPairs) != 1 { t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) return @@ -359,7 +359,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }(fw) acl := NewDefaultManager(fw) - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) if len(acl.peerRulesPairs) != 3 { t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) diff --git a/client/internal/debug/debug_linux.go b/client/internal/debug/debug_linux.go index 291531fea..b4907beca 100644 --- a/client/internal/debug/debug_linux.go +++ b/client/internal/debug/debug_linux.go @@ -59,6 +59,16 @@ func collectIPTablesRules() (string, error) { builder.WriteString("\n") } + // Collect ipset information + ipsetOutput, err := collectIPSets() + if err != nil { + log.Warnf("Failed to collect ipset information: %v", err) + } else { + builder.WriteString("=== ipset list output ===\n") + builder.WriteString(ipsetOutput) + builder.WriteString("\n") + } + builder.WriteString("=== iptables -v -n -L output ===\n") tables := []string{"filter", "nat", "mangle", "raw", "security"} @@ -78,6 +88,28 @@ func collectIPTablesRules() (string, error) { return builder.String(), nil } +// collectIPSets collects information about ipsets +func collectIPSets() (string, error) { + cmd := exec.Command("ipset", "list") + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + if strings.Contains(err.Error(), "executable file not found") { + return "", fmt.Errorf("ipset command not found: %w", err) + } + return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String()) + } + + ipsets := stdout.String() + if strings.TrimSpace(ipsets) == "" { + return "No ipsets found", nil + } + + return ipsets, nil +} + // collectIPTablesSave uses iptables-save to get rule definitions func collectIPTablesSave() (string, error) { cmd := exec.Command("iptables-save") diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 4c910a95f..5f03e0758 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -1,7 +1,6 @@ package dns_test import ( - "net" "testing" "github.com/miekg/dns" @@ -9,6 +8,7 @@ import ( "github.com/stretchr/testify/mock" nbdns "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dns/test" ) // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order @@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { r.SetQuestion("example.com.", dns.TypeA) // Create test writer - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup expectations - only highest priority handler should be called dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() @@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { r := new(dns.Msg) r.SetQuestion(tt.queryDomain, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) @@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { // Create and execute request r := new(dns.Msg) r.SetQuestion(tt.queryDomain, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) // Verify expectations @@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { }).Once() // Execute - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) // Verify all handlers were called in order @@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { handler3.AssertExpectations(t) } -// mockResponseWriter implements dns.ResponseWriter for testing -type mockResponseWriter struct { - mock.Mock -} - -func (m *mockResponseWriter) LocalAddr() net.Addr { return nil } -func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil } -func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil } -func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } -func (m *mockResponseWriter) Close() error { return nil } -func (m *mockResponseWriter) TsigStatus() error { return nil } -func (m *mockResponseWriter) TsigTimersOnly(bool) {} -func (m *mockResponseWriter) Hijack() {} - func TestHandlerChain_PriorityDeregistration(t *testing.T) { tests := []struct { name string @@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { // Create test request r := new(dns.Msg) r.SetQuestion(tt.query, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup expectations for priority, handler := range handlers { @@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) // Test 1: Initial state - w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Highest priority handler (routeHandler) should be called routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet @@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 2: Remove highest priority handler chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) - w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Now middle priority handler (matchHandler) should be called matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet @@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 3: Remove middle priority handler chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) - w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Now lowest priority handler (defaultHandler) should be called defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() @@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 4: Remove last handler chain.RemoveHandler(testDomain, nbdns.PriorityDefault) - w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain for _, m := range mocks { @@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { // Execute request r := new(dns.Msg) r.SetQuestion(tt.query, dns.TypeA) - chain.ServeDNS(&mockResponseWriter{}, r) + chain.ServeDNS(&test.MockResponseWriter{}, r) // Verify each handler was called exactly as expected for _, h := range tt.addHandlers { @@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { r := new(dns.Msg) r.SetQuestion(tt.query, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup handler expectations for pattern, handler := range handlers { @@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { handler := &nbdns.MockHandler{} r := new(dns.Msg) r.SetQuestion(tt.queryPattern, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // First verify no handler is called before adding any chain.ServeDNS(w, r) diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go deleted file mode 100644 index 76e18e3ce..000000000 --- a/client/internal/dns/local.go +++ /dev/null @@ -1,130 +0,0 @@ -package dns - -import ( - "fmt" - "strings" - "sync" - - "github.com/miekg/dns" - log "github.com/sirupsen/logrus" - - nbdns "github.com/netbirdio/netbird/dns" -) - -type registrationMap map[string]struct{} - -type localResolver struct { - registeredMap registrationMap - records sync.Map // key: string (domain_class_type), value: []dns.RR -} - -func (d *localResolver) MatchSubdomains() bool { - return true -} - -func (d *localResolver) stop() { -} - -// String returns a string representation of the local resolver -func (d *localResolver) String() string { - return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) -} - -// ID returns the unique handler ID -func (d *localResolver) id() handlerID { - return "local-resolver" -} - -// ServeDNS handles a DNS request -func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - if len(r.Question) > 0 { - log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - } - - replyMessage := &dns.Msg{} - replyMessage.SetReply(r) - replyMessage.RecursionAvailable = true - - // lookup all records matching the question - records := d.lookupRecords(r) - if len(records) > 0 { - replyMessage.Rcode = dns.RcodeSuccess - replyMessage.Answer = append(replyMessage.Answer, records...) - } else { - replyMessage.Rcode = dns.RcodeNameError - } - - err := w.WriteMsg(replyMessage) - if err != nil { - log.Debugf("got an error while writing the local resolver response, error: %v", err) - } -} - -// lookupRecords fetches *all* DNS records matching the first question in r. -func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR { - if len(r.Question) == 0 { - return nil - } - question := r.Question[0] - question.Name = strings.ToLower(question.Name) - key := buildRecordKey(question.Name, question.Qclass, question.Qtype) - - value, found := d.records.Load(key) - if !found { - // alternatively check if we have a cname - if question.Qtype != dns.TypeCNAME { - r.Question[0].Qtype = dns.TypeCNAME - return d.lookupRecords(r) - } - - return nil - } - - records, ok := value.([]dns.RR) - if !ok { - log.Errorf("failed to cast records to []dns.RR, records: %v", value) - return nil - } - - // if there's more than one record, rotate them (round-robin) - if len(records) > 1 { - first := records[0] - records = append(records[1:], first) - d.records.Store(key, records) - } - - return records -} - -// registerRecord stores a new record by appending it to any existing list -func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) { - rr, err := dns.NewRR(record.String()) - if err != nil { - return "", fmt.Errorf("register record: %w", err) - } - - rr.Header().Rdlength = record.Len() - header := rr.Header() - key := buildRecordKey(header.Name, header.Class, header.Rrtype) - - // load any existing slice of records, then append - existing, _ := d.records.LoadOrStore(key, []dns.RR{}) - records := existing.([]dns.RR) - records = append(records, rr) - - // store updated slice - d.records.Store(key, records) - return key, nil -} - -// deleteRecord removes *all* records under the recordKey. -func (d *localResolver) deleteRecord(recordKey string) { - d.records.Delete(dns.Fqdn(recordKey)) -} - -// buildRecordKey consistently generates a key: name_class_type -func buildRecordKey(name string, class, qType uint16) string { - return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType) -} - -func (d *localResolver) probeAvailability() {} diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go new file mode 100644 index 000000000..de3d8514b --- /dev/null +++ b/client/internal/dns/local/local.go @@ -0,0 +1,149 @@ +package local + +import ( + "fmt" + "slices" + "strings" + "sync" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/dns/types" + nbdns "github.com/netbirdio/netbird/dns" +) + +type Resolver struct { + mu sync.RWMutex + records map[dns.Question][]dns.RR +} + +func NewResolver() *Resolver { + return &Resolver{ + records: make(map[dns.Question][]dns.RR), + } +} + +func (d *Resolver) MatchSubdomains() bool { + return true +} + +// String returns a string representation of the local resolver +func (d *Resolver) String() string { + return fmt.Sprintf("local resolver [%d records]", len(d.records)) +} + +func (d *Resolver) Stop() {} + +// ID returns the unique handler ID +func (d *Resolver) ID() types.HandlerID { + return "local-resolver" +} + +func (d *Resolver) ProbeAvailability() {} + +// ServeDNS handles a DNS request +func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + log.Debugf("received local resolver request with no question") + return + } + question := r.Question[0] + question.Name = strings.ToLower(dns.Fqdn(question.Name)) + + log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass) + + replyMessage := &dns.Msg{} + replyMessage.SetReply(r) + replyMessage.RecursionAvailable = true + + // lookup all records matching the question + records := d.lookupRecords(question) + if len(records) > 0 { + replyMessage.Rcode = dns.RcodeSuccess + replyMessage.Answer = append(replyMessage.Answer, records...) + } else { + // TODO: return success if we have a different record type for the same name, relevant for search domains + replyMessage.Rcode = dns.RcodeNameError + } + + if err := w.WriteMsg(replyMessage); err != nil { + log.Warnf("failed to write the local resolver response: %v", err) + } +} + +// lookupRecords fetches *all* DNS records matching the first question in r. +func (d *Resolver) lookupRecords(question dns.Question) []dns.RR { + d.mu.RLock() + records, found := d.records[question] + + if !found { + d.mu.RUnlock() + // alternatively check if we have a cname + if question.Qtype != dns.TypeCNAME { + question.Qtype = dns.TypeCNAME + return d.lookupRecords(question) + } + return nil + } + + recordsCopy := slices.Clone(records) + d.mu.RUnlock() + + // if there's more than one record, rotate them (round-robin) + if len(recordsCopy) > 1 { + d.mu.Lock() + records = d.records[question] + if len(records) > 1 { + first := records[0] + records = append(records[1:], first) + d.records[question] = records + } + d.mu.Unlock() + } + + return recordsCopy +} + +func (d *Resolver) Update(update []nbdns.SimpleRecord) { + d.mu.Lock() + defer d.mu.Unlock() + + maps.Clear(d.records) + + for _, rec := range update { + if err := d.registerRecord(rec); err != nil { + log.Warnf("failed to register the record (%s): %v", rec, err) + continue + } + } +} + +// RegisterRecord stores a new record by appending it to any existing list +func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error { + d.mu.Lock() + defer d.mu.Unlock() + + return d.registerRecord(record) +} + +// registerRecord performs the registration with the lock already held +func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error { + rr, err := dns.NewRR(record.String()) + if err != nil { + return fmt.Errorf("register record: %w", err) + } + + rr.Header().Rdlength = record.Len() + header := rr.Header() + q := dns.Question{ + Name: strings.ToLower(dns.Fqdn(header.Name)), + Qtype: header.Rrtype, + Qclass: header.Class, + } + + d.records[q] = append(d.records[q], rr) + + return nil +} diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go new file mode 100644 index 000000000..1d38191e7 --- /dev/null +++ b/client/internal/dns/local/local_test.go @@ -0,0 +1,472 @@ +package local + +import ( + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/dns/test" + nbdns "github.com/netbirdio/netbird/dns" +) + +func TestLocalResolver_ServeDNS(t *testing.T) { + recordA := nbdns.SimpleRecord{ + Name: "peera.netbird.cloud.", + Type: 1, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "1.2.3.4", + } + + recordCNAME := nbdns.SimpleRecord{ + Name: "peerb.netbird.cloud.", + Type: 5, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "www.netbird.io", + } + + testCases := []struct { + name string + inputRecord nbdns.SimpleRecord + inputMSG *dns.Msg + responseShouldBeNil bool + }{ + { + name: "Should Resolve A Record", + inputRecord: recordA, + inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA), + }, + { + name: "Should Resolve CNAME Record", + inputRecord: recordCNAME, + inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME), + }, + { + name: "Should Not Write When Not Found A Record", + inputRecord: recordA, + inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), + responseShouldBeNil: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + resolver := NewResolver() + _ = resolver.RegisterRecord(testCase.inputRecord) + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, testCase.inputMSG) + + if responseMSG == nil || len(responseMSG.Answer) == 0 { + if testCase.responseShouldBeNil { + return + } + t.Fatalf("should write a response message") + } + + answerString := responseMSG.Answer[0].String() + if !strings.Contains(answerString, testCase.inputRecord.Name) { + t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString) + } + if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) { + t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString) + } + if !strings.Contains(answerString, testCase.inputRecord.RData) { + t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString) + } + }) + } +} + +// TestLocalResolver_Update_StaleRecord verifies that updating +// a record correctly replaces the old one, preventing stale entries. +func TestLocalResolver_Update_StaleRecord(t *testing.T) { + recordName := "host.example.com." + recordType := dns.TypeA + recordClass := dns.ClassINET + + record1 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1", + } + record2 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2", + } + + recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType} + + resolver := NewResolver() + + update1 := []nbdns.SimpleRecord{record1} + update2 := []nbdns.SimpleRecord{record2} + + // Apply first update + resolver.Update(update1) + + // Verify first update + resolver.mu.RLock() + rrSlice1, found1 := resolver.records[recordKey] + resolver.mu.RUnlock() + + require.True(t, found1, "Record key %s not found after first update", recordKey) + require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update") + assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData) + + // Apply second update + resolver.Update(update2) + + // Verify second update + resolver.mu.RLock() + rrSlice2, found2 := resolver.records[recordKey] + resolver.mu.RUnlock() + + require.True(t, found2, "Record key %s not found after second update", recordKey) + require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key") + assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData) + assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData) +} + +// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records +// with the same question are stored properly +func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) { + resolver := NewResolver() + + recordName := "multi.example.com." + recordType := dns.TypeA + + // Create two records with the same name and type but different IPs + record1 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1", + } + record2 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2", + } + + update := []nbdns.SimpleRecord{record1, record2} + + // Apply update with both records + resolver.Update(update) + + // Create question that matches both records + question := dns.Question{ + Name: recordName, + Qtype: recordType, + Qclass: dns.ClassINET, + } + + // Verify both records are stored + resolver.mu.RLock() + records, found := resolver.records[question] + resolver.mu.RUnlock() + + require.True(t, found, "Records for question %v not found", question) + require.Len(t, records, 2, "Should have exactly 2 records for the same question") + + // Verify both record data values are present + recordStrings := []string{records[0].String(), records[1].String()} + assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present") + assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present") +} + +// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion +func TestLocalResolver_RecordRotation(t *testing.T) { + resolver := NewResolver() + + recordName := "rotation.example.com." + recordType := dns.TypeA + + // Create three records with the same name and type but different IPs + record1 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1", + } + record2 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2", + } + record3 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3", + } + + update := []nbdns.SimpleRecord{record1, record2, record3} + + // Apply update with all three records + resolver.Update(update) + + msg := new(dns.Msg).SetQuestion(recordName, recordType) + + // First lookup - should return the records in original order + var responses [3]*dns.Msg + + // Perform three lookups to verify rotation + for i := 0; i < 3; i++ { + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responses[i] = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, msg) + } + + // Verify all three responses contain answers + for i, resp := range responses { + require.NotNil(t, resp, "Response %d should not be nil", i) + require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i) + } + + // Verify the first record in each response is different due to rotation + firstRecordIPs := []string{ + responses[0].Answer[0].String(), + responses[1].Answer[0].String(), + responses[2].Answer[0].String(), + } + + // Each record should be different (rotated) + assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation") + assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation") + assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation") + + // After three rotations, we should have cycled through all records + assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData) + assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData) + assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData) +} + +// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive +func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) { + resolver := NewResolver() + + // Create record with lowercase name + lowerCaseRecord := nbdns.SimpleRecord{ + Name: "lower.example.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "10.10.10.10", + } + + // Create record with mixed case name + mixedCaseRecord := nbdns.SimpleRecord{ + Name: "MiXeD.ExAmPlE.CoM.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "20.20.20.20", + } + + // Update resolver with the records + resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}) + + testCases := []struct { + name string + queryName string + expectedRData string + shouldResolve bool + }{ + { + name: "Query lowercase with lowercase record", + queryName: "lower.example.com.", + expectedRData: "10.10.10.10", + shouldResolve: true, + }, + { + name: "Query uppercase with lowercase record", + queryName: "LOWER.EXAMPLE.COM.", + expectedRData: "10.10.10.10", + shouldResolve: true, + }, + { + name: "Query mixed case with lowercase record", + queryName: "LoWeR.eXaMpLe.CoM.", + expectedRData: "10.10.10.10", + shouldResolve: true, + }, + { + name: "Query lowercase with mixed case record", + queryName: "mixed.example.com.", + expectedRData: "20.20.20.20", + shouldResolve: true, + }, + { + name: "Query uppercase with mixed case record", + queryName: "MIXED.EXAMPLE.COM.", + expectedRData: "20.20.20.20", + shouldResolve: true, + }, + { + name: "Query with different casing pattern", + queryName: "mIxEd.ExaMpLe.cOm.", + expectedRData: "20.20.20.20", + shouldResolve: true, + }, + { + name: "Query non-existent domain", + queryName: "nonexistent.example.com.", + shouldResolve: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var responseMSG *dns.Msg + + // Create DNS query with the test case name + msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA) + + // Create mock response writer to capture the response + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + // Perform DNS query + resolver.ServeDNS(responseWriter, msg) + + // Check if we expect a successful resolution + if !tc.shouldResolve { + if responseMSG == nil || len(responseMSG.Answer) == 0 { + // Expected no answer, test passes + return + } + t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer) + } + + // Verify we got a response + require.NotNil(t, responseMSG, "Should have received a response message") + require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer") + + // Verify the response contains the expected data + answerString := responseMSG.Answer[0].String() + assert.Contains(t, answerString, tc.expectedRData, + "Answer should contain the expected IP address %s, got: %s", + tc.expectedRData, answerString) + }) + } +} + +// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back +// to checking for CNAME records when the requested record type isn't found +func TestLocalResolver_CNAMEFallback(t *testing.T) { + resolver := NewResolver() + + // Create a CNAME record (but no A record for this name) + cnameRecord := nbdns.SimpleRecord{ + Name: "alias.example.com.", + Type: int(dns.TypeCNAME), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "target.example.com.", + } + + // Create an A record for the CNAME target + targetRecord := nbdns.SimpleRecord{ + Name: "target.example.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.100.100", + } + + // Update resolver with both records + resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}) + + testCases := []struct { + name string + queryName string + queryType uint16 + expectedType string + expectedRData string + shouldResolve bool + }{ + { + name: "Directly query CNAME record", + queryName: "alias.example.com.", + queryType: dns.TypeCNAME, + expectedType: "CNAME", + expectedRData: "target.example.com.", + shouldResolve: true, + }, + { + name: "Query A record but get CNAME fallback", + queryName: "alias.example.com.", + queryType: dns.TypeA, + expectedType: "CNAME", + expectedRData: "target.example.com.", + shouldResolve: true, + }, + { + name: "Query AAAA record but get CNAME fallback", + queryName: "alias.example.com.", + queryType: dns.TypeAAAA, + expectedType: "CNAME", + expectedRData: "target.example.com.", + shouldResolve: true, + }, + { + name: "Query direct A record", + queryName: "target.example.com.", + queryType: dns.TypeA, + expectedType: "A", + expectedRData: "192.168.100.100", + shouldResolve: true, + }, + { + name: "Query non-existent name", + queryName: "nonexistent.example.com.", + queryType: dns.TypeA, + shouldResolve: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var responseMSG *dns.Msg + + // Create DNS query with the test case parameters + msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType) + + // Create mock response writer to capture the response + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + // Perform DNS query + resolver.ServeDNS(responseWriter, msg) + + // Check if we expect a successful resolution + if !tc.shouldResolve { + if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess { + // Expected no resolution, test passes + return + } + t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer) + } + + // Verify we got a successful response + require.NotNil(t, responseMSG, "Should have received a response message") + require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code") + require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer") + + // Verify the response contains the expected data + answerString := responseMSG.Answer[0].String() + assert.Contains(t, answerString, tc.expectedType, + "Answer should be of type %s, got: %s", tc.expectedType, answerString) + assert.Contains(t, answerString, tc.expectedRData, + "Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString) + }) + } +} diff --git a/client/internal/dns/local_test.go b/client/internal/dns/local_test.go deleted file mode 100644 index 0a42b321a..000000000 --- a/client/internal/dns/local_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package dns - -import ( - "strings" - "testing" - - "github.com/miekg/dns" - - nbdns "github.com/netbirdio/netbird/dns" -) - -func TestLocalResolver_ServeDNS(t *testing.T) { - recordA := nbdns.SimpleRecord{ - Name: "peera.netbird.cloud.", - Type: 1, - Class: nbdns.DefaultClass, - TTL: 300, - RData: "1.2.3.4", - } - - recordCNAME := nbdns.SimpleRecord{ - Name: "peerb.netbird.cloud.", - Type: 5, - Class: nbdns.DefaultClass, - TTL: 300, - RData: "www.netbird.io", - } - - testCases := []struct { - name string - inputRecord nbdns.SimpleRecord - inputMSG *dns.Msg - responseShouldBeNil bool - }{ - { - name: "Should Resolve A Record", - inputRecord: recordA, - inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA), - }, - { - name: "Should Resolve CNAME Record", - inputRecord: recordCNAME, - inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME), - }, - { - name: "Should Not Write When Not Found A Record", - inputRecord: recordA, - inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), - responseShouldBeNil: true, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - resolver := &localResolver{ - registeredMap: make(registrationMap), - } - _, _ = resolver.registerRecord(testCase.inputRecord) - var responseMSG *dns.Msg - responseWriter := &mockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - responseMSG = m - return nil - }, - } - - resolver.ServeDNS(responseWriter, testCase.inputMSG) - - if responseMSG == nil || len(responseMSG.Answer) == 0 { - if testCase.responseShouldBeNil { - return - } - t.Fatalf("should write a response message") - } - - answerString := responseMSG.Answer[0].String() - if !strings.Contains(answerString, testCase.inputRecord.Name) { - t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString) - } - if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) { - t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString) - } - if !strings.Contains(answerString, testCase.inputRecord.RData) { - t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString) - } - }) - } -} diff --git a/client/internal/dns/mock_test.go b/client/internal/dns/mock_test.go deleted file mode 100644 index d52ae24da..000000000 --- a/client/internal/dns/mock_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package dns - -import ( - "net" - - "github.com/miekg/dns" -) - -type mockResponseWriter struct { - WriteMsgFunc func(m *dns.Msg) error -} - -func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error { - if rw.WriteMsgFunc != nil { - return rw.WriteMsgFunc(m) - } - return nil -} - -func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil } -func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil } -func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } -func (rw *mockResponseWriter) Close() error { return nil } -func (rw *mockResponseWriter) TsigStatus() error { return nil } -func (rw *mockResponseWriter) TsigTimersOnly(bool) {} -func (rw *mockResponseWriter) Hijack() {} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 65b90e5f0..3f49c23fd 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -15,6 +15,8 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -46,8 +48,6 @@ type Server interface { ProbeAvailability() } -type handlerID string - type nsGroupsByDomain struct { domain string groups []*nbdns.NameServerGroup @@ -61,7 +61,7 @@ type DefaultServer struct { mux sync.Mutex service service dnsMuxMap registeredHandlerMap - localResolver *localResolver + localResolver *local.Resolver wgInterface WGIface hostManager hostManager updateSerial uint64 @@ -84,9 +84,9 @@ type DefaultServer struct { type handlerWithStop interface { dns.Handler - stop() - probeAvailability() - id() handlerID + Stop() + ProbeAvailability() + ID() types.HandlerID } type handlerWrapper struct { @@ -95,7 +95,7 @@ type handlerWrapper struct { priority int } -type registeredHandlerMap map[handlerID]handlerWrapper +type registeredHandlerMap map[types.HandlerID]handlerWrapper // NewDefaultServer returns a new dns server func NewDefaultServer( @@ -171,16 +171,14 @@ func newDefaultServer( handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: handlerChain, - extraDomains: make(map[domain.Domain]int), - dnsMuxMap: make(registeredHandlerMap), - localResolver: &localResolver{ - registeredMap: make(registrationMap), - }, + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), + dnsMuxMap: make(registeredHandlerMap), + localResolver: local.NewResolver(), wgInterface: wgInterface, statusRecorder: statusRecorder, stateManager: stateManager, @@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() { wg.Add(1) go func(mux handlerWithStop) { defer wg.Done() - mux.probeAvailability() + mux.ProbeAvailability() }(mux.handler) } wg.Wait() @@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.service.Stop() } - localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones) + localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) if err != nil { return fmt.Errorf("local handler updater: %w", err) } @@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.updateMux(muxUpdates) // register local records - s.updateLocalResolver(localRecordsByDomain) + s.localResolver.Update(localRecords) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) @@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) { ) } -func (s *DefaultServer) buildLocalHandlerUpdate( - customZones []nbdns.CustomZone, -) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) { +func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) { var muxUpdates []handlerWrapper - localRecords := make(map[string][]nbdns.SimpleRecord) + var localRecords []nbdns.SimpleRecord for _, customZone := range customZones { if len(customZone.Records) == 0 { @@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate( priority: PriorityMatchDomain, }) - // group all records under this domain for _, record := range customZone.Records { - var class uint16 = dns.ClassINET if record.Class != nbdns.DefaultClass { log.Warnf("received an invalid class type: %s", record.Class) continue } - - key := buildRecordKey(record.Name, class, uint16(record.Type)) - - localRecords[key] = append(localRecords[key], record) + // zone records contain the fqdn, so we can just flatten them + localRecords = append(localRecords, record) } } @@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai } if len(handler.upstreamServers) == 0 { - handler.stop() + handler.Stop() log.Errorf("received a nameserver group with an invalid nameserver list") continue } @@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { // this will introduce a short period of time when the server is not able to handle DNS requests for _, existing := range s.dnsMuxMap { s.deregisterHandler([]string{existing.domain}, existing.priority) - existing.handler.stop() + existing.handler.Stop() } muxUpdateMap := make(registeredHandlerMap) @@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { containsRootUpdate = true } s.registerHandler([]string{update.domain}, update.handler, update.priority) - muxUpdateMap[update.handler.id()] = update + muxUpdateMap[update.handler.ID()] = update } // If there's no root update and we had a root handler, restore it @@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { s.dnsMuxMap = muxUpdateMap } -func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) { - // remove old records that are no longer present - for key := range s.localResolver.registeredMap { - _, found := update[key] - if !found { - s.localResolver.deleteRecord(key) - } - } - - updatedMap := make(registrationMap) - for _, recs := range update { - for _, rec := range recs { - // convert the record to a dns.RR and register - key, err := s.localResolver.registerRecord(rec) - if err != nil { - log.Warnf("got an error while registering the record (%s), error: %v", - rec.String(), err) - continue - } - - updatedMap[key] = struct{}{} - } - } - - s.localResolver.registeredMap = updatedMap -} - func getNSHostPort(ns nbdns.NameServer) string { return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index ed69b0e93..1c7c9b117 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -23,6 +23,9 @@ import ( "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe } func TestUpdateDNSServer(t *testing.T) { + nameServers := []nbdns.NameServer{ { IP: netip.MustParseAddr("8.8.8.8"), @@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) { }, } - dummyHandler := &localResolver{} + dummyHandler := local.NewResolver() testCases := []struct { name string initUpstreamMap registeredHandlerMap - initLocalMap registrationMap + initLocalRecords []nbdns.SimpleRecord initSerial uint64 inputSerial uint64 inputUpdate nbdns.Config shouldFail bool expectedUpstreamMap registeredHandlerMap - expectedLocalMap registrationMap + expectedLocalQs []dns.Question }{ { name: "Initial Config Should Succeed", - initLocalMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, @@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registeredHandlerMap{ - generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, priority: PriorityMatchDomain, }, - dummyHandler.id(): handlerWrapper{ + dummyHandler.ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, priority: PriorityMatchDomain, }, - generateDummyHandler(".", nameServers).id(): handlerWrapper{ + generateDummyHandler(".", nameServers).ID(): handlerWrapper{ domain: nbdns.RootZone, handler: dummyHandler, priority: PriorityDefault, }, }, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}, }, { - name: "New Config Should Succeed", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + name: "New Config Should Succeed", + initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ - domain: buildRecordKey(zoneRecords[0].Name, 1, 1), + generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ + domain: "netbird.cloud", handler: dummyHandler, priority: PriorityMatchDomain, }, @@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registeredHandlerMap{ - generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, priority: PriorityMatchDomain, @@ -216,22 +219,22 @@ func TestUpdateDNSServer(t *testing.T) { priority: PriorityMatchDomain, }, }, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, }, { - name: "Smaller Config Serial Should Be Skipped", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 2, - inputSerial: 1, - shouldFail: true, + name: "Smaller Config Serial Should Be Skipped", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 2, + inputSerial: 1, + shouldFail: true, }, { - name: "Empty NS Group Domain Or Not Primary Element Should Fail", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 0, - inputSerial: 1, + name: "Empty NS Group Domain Or Not Primary Element Should Fail", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -249,11 +252,11 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Invalid NS Group Nameservers list Should Fail", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 0, - inputSerial: 1, + name: "Invalid NS Group Nameservers list Should Fail", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -271,11 +274,11 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Invalid Custom Zone Records list Should Skip", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 0, - inputSerial: 1, + name: "Invalid Custom Zone Records list Should Skip", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -290,17 +293,17 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{ + expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{ domain: ".", handler: dummyHandler, priority: PriorityDefault, }}, }, { - name: "Empty Config Should Succeed and Clean Maps", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + name: "Empty Config Should Succeed and Clean Maps", + initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, priority: PriorityMatchDomain, @@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) { inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: true}, expectedUpstreamMap: make(registeredHandlerMap), - expectedLocalMap: make(registrationMap), + expectedLocalQs: []dns.Question{}, }, { - name: "Disabled Service Should clean map", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + name: "Disabled Service Should clean map", + initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, priority: PriorityMatchDomain, @@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) { inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: false}, expectedUpstreamMap: make(registeredHandlerMap), - expectedLocalMap: make(registrationMap), + expectedLocalQs: []dns.Question{}, }, } @@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) { }() dnsServer.dnsMuxMap = testCase.initUpstreamMap - dnsServer.localResolver.registeredMap = testCase.initLocalMap + dnsServer.localResolver.Update(testCase.initLocalRecords) dnsServer.updateSerial = testCase.initSerial err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) @@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) { } } - if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) { - t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap)) + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + for _, q := range testCase.expectedLocalQs { + dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{ + Question: []dns.Question{q}, + }) } - for key := range testCase.expectedLocalMap { - _, found := dnsServer.localResolver.registeredMap[key] - if !found { - t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap) - } + if len(testCase.expectedLocalQs) > 0 { + assert.NotNil(t, responseMSG, "response message should not be nil") + assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success") + assert.NotEmpty(t, responseMSG.Answer, "response message should have answers") } }) } @@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { dnsServer.dnsMuxMap = registeredHandlerMap{ "id1": handlerWrapper{ domain: zoneRecords[0].Name, - handler: &localResolver{}, + handler: &local.Resolver{}, priority: PriorityMatchDomain, }, } - dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}} + //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} + dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}) dnsServer.updateSerial = 0 nameServers := []nbdns.NameServer{ @@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) { } time.Sleep(100 * time.Millisecond) defer dnsServer.Stop() - _, err = dnsServer.localResolver.registerRecord(zoneRecords[0]) + err = dnsServer.localResolver.RegisterRecord(zoneRecords[0]) if err != nil { t.Error(err) } @@ -630,13 +642,11 @@ func TestDNSServerStartStop(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { hostManager := &mockHostConfigurator{} server := DefaultServer{ - ctx: context.Background(), - service: NewServiceViaMemory(&mocWGIface{}), - localResolver: &localResolver{ - registeredMap: make(registrationMap), - }, - handlerChain: NewHandlerChain(), - hostManager: hostManager, + ctx: context.Background(), + service: NewServiceViaMemory(&mocWGIface{}), + localResolver: local.NewResolver(), + handlerChain: NewHandlerChain(), + hostManager: hostManager, currentConfig: HostDNSConfig{ Domains: []DomainConfig{ {false, "domain0", false}, @@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { t.Run(tc.name, func(t *testing.T) { r := new(dns.Msg) r.SetQuestion(tc.query, dns.TypeA) - w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} if mh, ok := tc.expectedHandler.(*MockHandler); ok { mh.On("ServeDNS", mock.Anything, r).Once() @@ -1037,9 +1047,9 @@ type mockHandler struct { } func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} -func (m *mockHandler) stop() {} -func (m *mockHandler) probeAvailability() {} -func (m *mockHandler) id() handlerID { return handlerID(m.Id) } +func (m *mockHandler) Stop() {} +func (m *mockHandler) ProbeAvailability() {} +func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) } type mockService struct{} @@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { name string initialHandlers registeredHandlerMap updates []handlerWrapper - expectedHandlers map[string]string // map[handlerID]domain + expectedHandlers map[string]string // map[HandlerID]domain description string }{ { @@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { // Check each expected handler for id, expectedDomain := range tt.expectedHandlers { - handler, exists := server.dnsMuxMap[handlerID(id)] + handler, exists := server.dnsMuxMap[types.HandlerID(id)] assert.True(t, exists, "Expected handler %s not found", id) if exists { assert.Equal(t, expectedDomain, handler.domain, @@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) { } // Verify no unexpected handlers exist - for handlerID := range server.dnsMuxMap { - _, expected := tt.expectedHandlers[string(handlerID)] - assert.True(t, expected, "Unexpected handler found: %s", handlerID) + for HandlerID := range server.dnsMuxMap { + _, expected := tt.expectedHandlers[string(HandlerID)] + assert.True(t, expected, "Unexpected handler found: %s", HandlerID) } // Verify the handlerChain state and order @@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) { handlerChain: NewHandlerChain(), wgInterface: &mocWGIface{}, hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), @@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) { ctx: context.Background(), handlerChain: NewHandlerChain(), hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), @@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) { ctx: context.Background(), handlerChain: NewHandlerChain(), hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), @@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) { ctx: context.Background(), handlerChain: NewHandlerChain(), hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), diff --git a/client/internal/dns/test/mock.go b/client/internal/dns/test/mock.go new file mode 100644 index 000000000..1db452805 --- /dev/null +++ b/client/internal/dns/test/mock.go @@ -0,0 +1,26 @@ +package test + +import ( + "net" + + "github.com/miekg/dns" +) + +type MockResponseWriter struct { + WriteMsgFunc func(m *dns.Msg) error +} + +func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error { + if rw.WriteMsgFunc != nil { + return rw.WriteMsgFunc(m) + } + return nil +} + +func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil } +func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil } +func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (rw *MockResponseWriter) Close() error { return nil } +func (rw *MockResponseWriter) TsigStatus() error { return nil } +func (rw *MockResponseWriter) TsigTimersOnly(bool) {} +func (rw *MockResponseWriter) Hijack() {} diff --git a/client/internal/dns/types/types.go b/client/internal/dns/types/types.go new file mode 100644 index 000000000..5a8be03b7 --- /dev/null +++ b/client/internal/dns/types/types.go @@ -0,0 +1,3 @@ +package types + +type HandlerID string diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index fa69d4934..2fbfb3b91 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -19,6 +19,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" ) @@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string { } // ID returns the unique handler ID -func (u *upstreamResolverBase) id() handlerID { +func (u *upstreamResolverBase) ID() types.HandlerID { servers := slices.Clone(u.upstreamServers) slices.Sort(servers) hash := sha256.New() hash.Write([]byte(u.domain + ":")) hash.Write([]byte(strings.Join(servers, ","))) - return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) + return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) } func (u *upstreamResolverBase) MatchSubdomains() bool { return true } -func (u *upstreamResolverBase) stop() { +func (u *upstreamResolverBase) Stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() } @@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) { ) } -// probeAvailability tests all upstream servers simultaneously and +// ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work -func (u *upstreamResolverBase) probeAvailability() { +func (u *upstreamResolverBase) ProbeAvailability() { u.mutex.Lock() defer u.mutex.Unlock() diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 5dbcc9f79..13bc91a37 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -8,6 +8,8 @@ import ( "time" "github.com/miekg/dns" + + "github.com/netbirdio/netbird/client/internal/dns/test" ) func TestUpstreamResolver_ServeDNS(t *testing.T) { @@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { } var responseMSG *dns.Msg - responseWriter := &mockResponseWriter{ + responseWriter := &test.MockResponseWriter{ WriteMsgFunc: func(m *dns.Msg) error { responseMSG = m return nil @@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { resolver.failsTillDeact = 0 resolver.reactivatePeriod = time.Microsecond * 100 - responseWriter := &mockResponseWriter{ + responseWriter := &test.MockResponseWriter{ WriteMsgFunc: func(m *dns.Msg) error { return nil }, } diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 2d69ce858..45b479632 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -3,6 +3,7 @@ package dnsfwd import ( "context" "errors" + "fmt" "math" "net" "net/netip" @@ -10,11 +11,16 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) const errResolveFailed = "failed to resolve query for domain=%s: %v" @@ -23,79 +29,116 @@ const upstreamTimeout = 15 * time.Second type DNSForwarder struct { listenAddress string ttl uint32 - domains []string statusRecorder *peer.Status dnsServer *dns.Server mux *dns.ServeMux + tcpServer *dns.Server + tcpMux *dns.ServeMux - resId sync.Map + mutex sync.RWMutex + fwdEntries []*ForwarderEntry + firewall firewall.Manager } -func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, ttl: ttl, + firewall: firewall, statusRecorder: statusRecorder, } } -func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error { - log.Infof("listen DNS forwarder on address=%s", f.listenAddress) - mux := dns.NewServeMux() +func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { + log.Infof("starting DNS forwarder on address=%s", f.listenAddress) - dnsServer := &dns.Server{ + // UDP server + mux := dns.NewServeMux() + f.mux = mux + f.dnsServer = &dns.Server{ Addr: f.listenAddress, Net: "udp", Handler: mux, } - f.dnsServer = dnsServer - f.mux = mux + // TCP server + tcpMux := dns.NewServeMux() + f.tcpMux = tcpMux + f.tcpServer = &dns.Server{ + Addr: f.listenAddress, + Net: "tcp", + Handler: tcpMux, + } - f.UpdateDomains(domains, resIds) + f.UpdateDomains(entries) - return dnsServer.ListenAndServe() + errCh := make(chan error, 2) + + go func() { + log.Infof("DNS UDP listener running on %s", f.listenAddress) + errCh <- f.dnsServer.ListenAndServe() + }() + go func() { + log.Infof("DNS TCP listener running on %s", f.listenAddress) + errCh <- f.tcpServer.ListenAndServe() + }() + + // return the first error we get (e.g. bind failure or shutdown) + return <-errCh } +func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { + f.mutex.Lock() + defer f.mutex.Unlock() -func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) { - log.Debugf("Updating domains from %v to %v", f.domains, domains) - - for _, d := range f.domains { - f.mux.HandleRemove(d) + if f.mux == nil { + log.Debug("DNS mux is nil, skipping domain update") + f.fwdEntries = entries + return } - f.resId.Clear() - newDomains := filterDomains(domains) + oldDomains := filterDomains(f.fwdEntries) + for _, d := range oldDomains { + f.mux.HandleRemove(d.PunycodeString()) + f.tcpMux.HandleRemove(d.PunycodeString()) + } + + newDomains := filterDomains(entries) for _, d := range newDomains { - f.mux.HandleFunc(d, f.handleDNSQuery) + f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP) + f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP) } - for domain, resId := range resIds { - if domain != "" { - f.resId.Store(domain, resId) - } - } - - f.domains = newDomains + f.fwdEntries = entries + log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) } func (f *DNSForwarder) Close(ctx context.Context) error { - if f.dnsServer == nil { - return nil + var result *multierror.Error + + if f.dnsServer != nil { + if err := f.dnsServer.ShutdownContext(ctx); err != nil { + result = multierror.Append(result, fmt.Errorf("UDP shutdown: %w", err)) + } } - return f.dnsServer.ShutdownContext(ctx) + if f.tcpServer != nil { + if err := f.tcpServer.ShutdownContext(ctx); err != nil { + result = multierror.Append(result, fmt.Errorf("TCP shutdown: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(result) } -func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { +func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg { if len(query.Question) == 0 { - return + return nil } - log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", - query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass) - question := query.Question[0] - domain := question.Name + log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", + question.Name, question.Qtype, question.Qclass) + + domain := strings.ToLower(question.Name) resp := query.SetReply(query) var network string @@ -111,41 +154,96 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write DNS response: %v", err) } - return + return nil } ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) if err != nil { - f.handleDNSError(w, resp, domain, err) + f.handleDNSError(w, query, resp, domain, err) + return nil + } + + f.updateInternalState(domain, ips) + f.addIPsToResponse(resp, domain, ips) + + return resp +} + +func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { + + resp := f.handleDNSQuery(w, query) + if resp == nil { return } - resId := f.getResIdForDomain(strings.TrimSuffix(domain, ".")) - if resId != "" { - for _, ip := range ips { - var ipWithSuffix string - if ip.Is4() { - ipWithSuffix = ip.String() + "/32" - log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix) - } else { - ipWithSuffix = ip.String() + "/128" - log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix) - } - f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId) - } + opt := query.IsEdns0() + maxSize := dns.MinMsgSize + if opt != nil { + // client advertised a larger EDNS0 buffer + maxSize = int(opt.UDPSize()) } - f.addIPsToResponse(resp, domain, ips) + // if our response is too big, truncate and set the TC bit + if resp.Len() > maxSize { + resp.Truncate(maxSize) + } if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write DNS response: %v", err) } } +func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { + resp := f.handleDNSQuery(w, query) + if resp == nil { + return + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} + +func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { + var prefixes []netip.Prefix + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) + if mostSpecificResId != "" { + for _, ip := range ips { + var prefix netip.Prefix + if ip.Is4() { + prefix = netip.PrefixFrom(ip, 32) + } else { + prefix = netip.PrefixFrom(ip, 128) + } + prefixes = append(prefixes, prefix) + f.statusRecorder.AddResolvedIPLookupEntry(prefix, mostSpecificResId) + } + } + + if f.firewall != nil { + f.updateFirewall(matchingEntries, prefixes) + } +} + +func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixes []netip.Prefix) { + var merr *multierror.Error + for _, entry := range matchingEntries { + if err := f.firewall.UpdateSet(entry.Set, prefixes); err != nil { + merr = multierror.Append(merr, fmt.Errorf("update set for domain=%s: %w", entry.Domain, err)) + } + } + if merr != nil { + log.Errorf("failed to update firewall sets (%d/%d): %v", + len(merr.Errors), + len(matchingEntries), + nberrors.FormatErrorOrNil(merr)) + } +} + // handleDNSError processes DNS lookup errors and sends an appropriate error response -func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) { +func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) { var dnsErr *net.DNSError switch { @@ -157,7 +255,7 @@ func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domai } if dnsErr.Server != "" { - log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) + log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err) } else { log.Warnf(errResolveFailed, domain, err) } @@ -204,45 +302,53 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti } } -func (f *DNSForwarder) getResIdForDomain(domain string) string { - var selectedResId string +// getMatchingEntries retrieves the resource IDs for a given domain. +// It returns the most specific match and all matching resource IDs. +func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) { + var selectedResId route.ResID var bestScore int + var matches []*ForwarderEntry - f.resId.Range(func(key, value interface{}) bool { + f.mutex.RLock() + defer f.mutex.RUnlock() + + for _, entry := range f.fwdEntries { var score int - pattern := key.(string) + pattern := entry.Domain.PunycodeString() switch { case strings.HasPrefix(pattern, "*."): baseDomain := strings.TrimPrefix(pattern, "*.") - if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) { + + if strings.EqualFold(domain, baseDomain) || strings.HasSuffix(domain, "."+baseDomain) { score = len(baseDomain) + matches = append(matches, entry) } case domain == pattern: score = math.MaxInt + matches = append(matches, entry) default: - return true + continue } if score > bestScore { bestScore = score - selectedResId = value.(string) + selectedResId = entry.ResID } - return true - }) + } - return selectedResId + return selectedResId, matches } // filterDomains returns a list of normalized domains -func filterDomains(domains []string) []string { - newDomains := make([]string, 0, len(domains)) - for _, d := range domains { - if d == "" { +func filterDomains(entries []*ForwarderEntry) domain.List { + newDomains := make(domain.List, 0, len(entries)) + for _, d := range entries { + if d.Domain == "" { log.Warn("empty domain in DNS forwarder") continue } - newDomains = append(newDomains, nbdns.NormalizeZone(d)) + newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString()))) } return newDomains } diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 88ffc2af3..f0829bbbd 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -1,56 +1,61 @@ package dnsfwd import ( - "sync" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) -func TestGetResIdForDomain(t *testing.T) { +func Test_getMatchingEntries(t *testing.T) { testCases := []struct { name string - storedMappings map[string]string // key: domain pattern, value: resId + storedMappings map[string]route.ResID // key: domain pattern, value: resId queryDomain string - expectedResId string + expectedResId route.ResID }{ { name: "Empty map returns empty string", - storedMappings: map[string]string{}, + storedMappings: map[string]route.ResID{}, queryDomain: "example.com", expectedResId: "", }, { name: "Exact match returns stored resId", - storedMappings: map[string]string{"example.com": "res1"}, + storedMappings: map[string]route.ResID{"example.com": "res1"}, queryDomain: "example.com", expectedResId: "res1", }, { name: "Wildcard pattern matches base domain", - storedMappings: map[string]string{"*.example.com": "res2"}, + storedMappings: map[string]route.ResID{"*.example.com": "res2"}, queryDomain: "example.com", expectedResId: "res2", }, { name: "Wildcard pattern matches subdomain", - storedMappings: map[string]string{"*.example.com": "res3"}, + storedMappings: map[string]route.ResID{"*.example.com": "res3"}, queryDomain: "foo.example.com", expectedResId: "res3", }, { name: "Wildcard pattern does not match different domain", - storedMappings: map[string]string{"*.example.com": "res4"}, + storedMappings: map[string]route.ResID{"*.example.com": "res4"}, queryDomain: "foo.notexample.com", expectedResId: "", }, { name: "Non-wildcard pattern does not match subdomain", - storedMappings: map[string]string{"example.com": "res5"}, + storedMappings: map[string]route.ResID{"example.com": "res5"}, queryDomain: "foo.example.com", expectedResId: "", }, { name: "Exact match over overlapping wildcard", - storedMappings: map[string]string{ + storedMappings: map[string]route.ResID{ "*.example.com": "resWildcard", "foo.example.com": "resExact", }, @@ -59,7 +64,7 @@ func TestGetResIdForDomain(t *testing.T) { }, { name: "Overlapping wildcards: Select more specific wildcard", - storedMappings: map[string]string{ + storedMappings: map[string]route.ResID{ "*.example.com": "resA", "*.sub.example.com": "resB", }, @@ -68,7 +73,7 @@ func TestGetResIdForDomain(t *testing.T) { }, { name: "Wildcard multi-level subdomain match", - storedMappings: map[string]string{ + storedMappings: map[string]route.ResID{ "*.example.com": "resMulti", }, queryDomain: "a.b.example.com", @@ -78,18 +83,21 @@ func TestGetResIdForDomain(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - fwd := &DNSForwarder{ - resId: sync.Map{}, - } + fwd := &DNSForwarder{} + var entries []*ForwarderEntry for domainPattern, resId := range tc.storedMappings { - fwd.resId.Store(domainPattern, resId) + d, err := domain.FromString(domainPattern) + require.NoError(t, err) + entries = append(entries, &ForwarderEntry{ + Domain: d, + ResID: resId, + }) } + fwd.UpdateDomains(entries) - got := fwd.getResIdForDomain(tc.queryDomain) - if got != tc.expectedResId { - t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got) - } + got, _ := fwd.getMatchingEntries(tc.queryDomain) + assert.Equal(t, got, tc.expectedResId) }) } } diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a51ae7abb..91abce823 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -11,6 +11,8 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) const ( @@ -19,11 +21,19 @@ const ( dnsTTL = 60 //seconds ) +// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. +type ForwarderEntry struct { + Domain domain.Domain + ResID route.ResID + Set firewall.Set +} + type Manager struct { firewall firewall.Manager statusRecorder *peer.Status fwRules []firewall.Rule + tcpRules []firewall.Rule dnsForwarder *DNSForwarder } @@ -34,7 +44,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { } } -func (m *Manager) Start(domains []string, resIds map[string]string) error { +func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { log.Infof("starting DNS forwarder") if m.dnsForwarder != nil { return nil @@ -44,9 +54,9 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.statusRecorder) + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder) go func() { - if err := m.dnsForwarder.Listen(domains, resIds); err != nil { + if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists log.Errorf("failed to start DNS forwarder, err: %v", err) } @@ -55,12 +65,12 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error { return nil } -func (m *Manager) UpdateDomains(domains []string, resIds map[string]string) { +func (m *Manager) UpdateDomains(entries []*ForwarderEntry) { if m.dnsForwarder == nil { return } - m.dnsForwarder.UpdateDomains(domains, resIds) + m.dnsForwarder.UpdateDomains(entries) } func (m *Manager) Stop(ctx context.Context) error { @@ -81,34 +91,47 @@ func (m *Manager) Stop(ctx context.Context) error { return nberrors.FormatErrorOrNil(mErr) } -func (h *Manager) allowDNSFirewall() error { +func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, Values: []uint16{ListenPort}, } - if h.firewall == nil { + if m.firewall == nil { return nil } - dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") + dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") if err != nil { log.Errorf("failed to add allow DNS router rules, err: %v", err) return err } - h.fwRules = dnsRules + m.fwRules = dnsRules + + tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") + if err != nil { + log.Errorf("failed to add allow DNS router rules, err: %v", err) + return err + } + m.tcpRules = tcpRules return nil } -func (h *Manager) dropDNSFirewall() error { +func (m *Manager) dropDNSFirewall() error { var mErr *multierror.Error - for _, rule := range h.fwRules { - if err := h.firewall.DeletePeerRule(rule); err != nil { + for _, rule := range m.fwRules { + if err := m.firewall.DeletePeerRule(rule); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) + } + } + for _, rule := range m.tcpRules { + if err := m.firewall.DeletePeerRule(rule); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) } } - h.fwRules = nil + m.fwRules = nil + m.tcpRules = nil return nberrors.FormatErrorOrNil(mErr) } diff --git a/client/internal/engine.go b/client/internal/engine.go index c377c12e1..b16232883 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -527,7 +527,7 @@ func (e *Engine) blockLanAccess() { if _, err := e.firewall.AddRouteFiltering( nil, []netip.Prefix{v4}, - network, + firewallManager.Network{Prefix: network}, firewallManager.ProtocolALL, nil, nil, @@ -960,21 +960,21 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } - // DNS forwarder dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) - dnsRouteDomains, resourceIds := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) - e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains, resourceIds) + // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - // acls might need routing to be enabled, so we apply after routes if e.acl != nil { - e.acl.ApplyFiltering(networkMap) + e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag) } + fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) + // Ingress forward rules if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil { log.Errorf("failed to update forward rules, err: %v", err) @@ -1079,29 +1079,24 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { return routes } -func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) ([]string, map[string]string) { - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} - } - - var dnsRoutes []string - resIds := make(map[string]string) - for _, protoRoute := range protoRoutes { - if len(protoRoute.Domains) == 0 { +func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderEntry { + var entries []*dnsfwd.ForwarderEntry + for _, route := range routes { + if len(route.Domains) == 0 { continue } - if protoRoute.Peer == myPubKey { - dnsRoutes = append(dnsRoutes, protoRoute.Domains...) - // resource ID is the first part of the ID - resId := strings.Split(protoRoute.ID, ":") - for _, domain := range protoRoute.Domains { - if len(resId) > 0 { - resIds[domain] = resId[0] - } + if route.Peer == myPubKey { + domainSet := firewallManager.NewDomainSet(route.Domains) + for _, d := range route.Domains { + entries = append(entries, &dnsfwd.ForwarderEntry{ + Domain: d, + Set: domainSet, + ResID: route.GetResourceID(), + }) } } } - return dnsRoutes, resIds + return entries } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { @@ -1751,7 +1746,10 @@ func (e *Engine) GetWgAddr() net.IP { } // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag -func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[string]string) { +func (e *Engine) updateDNSForwarder( + enabled bool, + fwdEntries []*dnsfwd.ForwarderEntry, +) { if !enabled { if e.dnsForwardMgr == nil { return @@ -1762,18 +1760,18 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[s return } - if len(domains) > 0 { - log.Infof("enable domain router service for domains: %v", domains) + if len(fwdEntries) > 0 { if e.dnsForwardMgr == nil { e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(domains, resIds); err != nil { + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } + + log.Infof("started domain router service with %d entries", len(fwdEntries)) } else { - log.Infof("update domain router service for domains: %v", domains) - e.dnsForwardMgr.UpdateDomains(domains, resIds) + e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") diff --git a/client/internal/networkmonitor/check_change_bsd.go b/client/internal/networkmonitor/check_change_bsd.go index bb327a877..f5eb2c739 100644 --- a/client/internal/networkmonitor/check_change_bsd.go +++ b/client/internal/networkmonitor/check_change_bsd.go @@ -19,7 +19,7 @@ import ( func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) if err != nil { - return fmt.Errorf("failed to open routing socket: %v", err) + return fmt.Errorf("open routing socket: %v", err) } defer func() { err := unix.Close(fd) diff --git a/client/internal/networkmonitor/check_change_windows.go b/client/internal/networkmonitor/check_change_windows.go index 582865738..814584863 100644 --- a/client/internal/networkmonitor/check_change_windows.go +++ b/client/internal/networkmonitor/check_change_windows.go @@ -13,7 +13,7 @@ import ( func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { routeMonitor, err := systemops.NewRouteMonitor(ctx) if err != nil { - return fmt.Errorf("failed to create route monitor: %w", err) + return fmt.Errorf("create route monitor: %w", err) } defer func() { if err := routeMonitor.Stop(); err != nil { @@ -38,35 +38,49 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er } func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool { - intf := "" - if route.Interface != nil { - intf = route.Interface.Name - if isSoftInterface(intf) { - log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf) - return false - } + if intf := route.NextHop.Intf; intf != nil && isSoftInterface(intf.Name) { + log.Debugf("Network monitor: ignoring default route change for next hop with soft interface %s", route.NextHop) + return false + } + + // TODO: for the empty nexthop ip (on-link), determine the family differently + nexthop := nexthopv4 + if route.NextHop.IP.Is6() { + nexthop = nexthopv6 } switch route.Type { - case systemops.RouteModified: - // TODO: get routing table to figure out if our route is affected for modified routes - log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf) - return true - case systemops.RouteAdded: - if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP { - log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf) - return true - } + case systemops.RouteModified, systemops.RouteAdded: + return handleRouteAddedOrModified(route, nexthop) case systemops.RouteDeleted: - if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { - log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf) - return true - } + return handleRouteDeleted(route, nexthop) } return false } +func handleRouteAddedOrModified(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool { + // For added/modified routes, we care about different next hops + if !nexthop.Equal(route.NextHop) { + action := "changed" + if route.Type == systemops.RouteAdded { + action = "added" + } + log.Infof("Network monitor: default route %s: via %s", action, route.NextHop) + return true + } + return false +} + +func handleRouteDeleted(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool { + // For deleted routes, we care about our tracked next hop being deleted + if nexthop.Equal(route.NextHop) { + log.Infof("Network monitor: default route removed: via %s", route.NextHop) + return true + } + return false +} + func isSoftInterface(name string) bool { return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo") } diff --git a/client/internal/networkmonitor/check_change_windows_test.go b/client/internal/networkmonitor/check_change_windows_test.go new file mode 100644 index 000000000..29ff34dca --- /dev/null +++ b/client/internal/networkmonitor/check_change_windows_test.go @@ -0,0 +1,404 @@ +package networkmonitor + +import ( + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func TestRouteChanged(t *testing.T) { + tests := []struct { + name string + route systemops.RouteUpdate + nexthopv4 systemops.Nexthop + nexthopv6 systemops.Nexthop + expected bool + }{ + { + name: "soft interface should be ignored", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Name: "ISATAP-Interface", // isSoftInterface checks name + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.2"), + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: false, + }, + { + name: "modified route with different v4 nexthop IP should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.2"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: true, + }, + { + name: "modified route with same v4 nexthop (IP and Intf Index) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: false, + }, + { + name: "added route with different v6 nexthop IP should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteAdded, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::2"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + expected: true, + }, + { + name: "added route with same v6 nexthop (IP and Intf Index) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteAdded, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + expected: false, + }, + { + name: "deleted route matching tracked v4 nexthop (IP and Intf Index) should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: true, + }, + { + name: "deleted route not matching tracked v4 nexthop (different IP) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.3"), // Different IP + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: false, + }, + { + name: "modified v4 route with same IP, different Intf Index should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "modified v4 route with same IP, one Intf nil, other non-nil should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: nil, // Intf is nil + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, // Tracked Intf is not nil + }, + expected: true, + }, + { + name: "added v4 route with same IP, different Intf Index should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteAdded, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "deleted v4 route with same IP, different Intf Index should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ // This is the route being deleted + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv4: systemops.Nexthop{ // This is our tracked nexthop + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + expected: false, // Because nexthopv4.Equal(route.NextHop) will be false + }, + { + name: "modified v6 route with different IP, same Intf Index should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::3"), // Different IP + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "modified v6 route with same IP, different Intf Index should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "modified v6 route with same IP, same Intf Index should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: false, + }, + { + name: "deleted v6 route matching tracked nexthop (IP and Intf Index) should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "deleted v6 route not matching tracked nexthop (different IP) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::3"), // Different IP + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: false, + }, + { + name: "deleted v6 route not matching tracked nexthop (same IP, different Intf Index) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ // This is the route being deleted + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ // This is our tracked nexthop + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + expected: false, + }, + { + name: "unknown route type should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteUpdateType(99), // Unknown type + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.2"), // Different from route.NextHop + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := routeChanged(tt.route, tt.nexthopv4, tt.nexthopv6) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsSoftInterface(t *testing.T) { + tests := []struct { + name string + ifname string + expected bool + }{ + { + name: "ISATAP interface should be detected", + ifname: "ISATAP tunnel adapter", + expected: true, + }, + { + name: "lowercase soft interface should be detected", + ifname: "isatap.{14A5CF17-CA72-43EC-B4EA-B4B093641B7D}", + expected: true, + }, + { + name: "Teredo interface should be detected", + ifname: "Teredo Tunneling Pseudo-Interface", + expected: true, + }, + { + name: "regular interface should not be detected as soft", + ifname: "eth0", + expected: false, + }, + { + name: "another regular interface should not be detected as soft", + ifname: "wlan0", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isSoftInterface(tt.ifname) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index 5896b66b6..accdd9c9d 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -118,9 +118,12 @@ func (nw *NetworkMonitor) Stop() { } func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) { + defer close(event) for { if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil { - close(event) + if !errors.Is(err, context.Canceled) { + log.Errorf("Network monitor: failed to check for changes: %v", err) + } return } // prevent blocking diff --git a/client/internal/peer/route.go b/client/internal/peer/route.go index c3567dcc9..e5e315e3c 100644 --- a/client/internal/peer/route.go +++ b/client/internal/peer/route.go @@ -6,12 +6,14 @@ import ( "sync" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/route" ) // routeEntry holds the route prefix and the corresponding resource ID. type routeEntry struct { prefix netip.Prefix - resourceID string + resourceID route.ResID } type routeIDLookup struct { @@ -24,7 +26,7 @@ type routeIDLookup struct { resolvedIPs sync.Map } -func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) { +func (r *routeIDLookup) AddLocalRouteID(resourceID route.ResID, route netip.Prefix) { r.localLock.Lock() defer r.localLock.Unlock() @@ -56,7 +58,7 @@ func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) { } } -func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) { +func (r *routeIDLookup) AddRemoteRouteID(resourceID route.ResID, route netip.Prefix) { r.remoteLock.Lock() defer r.remoteLock.Unlock() @@ -87,7 +89,7 @@ func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) { } } -func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) { +func (r *routeIDLookup) AddResolvedIP(resourceID route.ResID, route netip.Prefix) { r.resolvedIPs.Store(route.Addr(), resourceID) } @@ -97,19 +99,19 @@ func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) { // Lookup returns the resource ID for the given IP address // and a bool indicating if the IP is an exit node. -func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) { +func (r *routeIDLookup) Lookup(ip netip.Addr) (route.ResID, bool) { if res, ok := r.resolvedIPs.Load(ip); ok { - return res.(string), false + return res.(route.ResID), false } - var resourceID string + var resourceID route.ResID var isExitNode bool r.localLock.RLock() for _, entry := range r.localRoutes { if entry.prefix.Contains(ip) { resourceID = entry.resourceID - isExitNode = (entry.prefix.Bits() == 0) + isExitNode = entry.prefix.Bits() == 0 break } } @@ -120,7 +122,7 @@ func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) { for _, entry := range r.remoteRoutes { if entry.prefix.Contains(ip) { resourceID = entry.resourceID - isExitNode = (entry.prefix.Bits() == 0) + isExitNode = entry.prefix.Bits() == 0 break } } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 9b3fc744d..3eca6a8c9 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/management/domain" relayClient "github.com/netbirdio/netbird/relay/client" + "github.com/netbirdio/netbird/route" ) const eventQueueSize = 10 @@ -313,7 +314,7 @@ func (d *Status) UpdatePeerState(receivedState State) error { return nil } -func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error { +func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error { d.mux.Lock() defer d.mux.Unlock() @@ -581,7 +582,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { } // AddLocalPeerStateRoute adds a route to the local peer state -func (d *Status) AddLocalPeerStateRoute(route, resourceId string) { +func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) { d.mux.Lock() defer d.mux.Unlock() @@ -611,14 +612,11 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) { } // AddResolvedIPLookupEntry adds a resolved IP lookup entry -func (d *Status) AddResolvedIPLookupEntry(route, resourceId string) { +func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.ResID) { d.mux.Lock() defer d.mux.Unlock() - pref, err := netip.ParsePrefix(route) - if err == nil { - d.routeIDLookup.AddResolvedIP(resourceId, pref) - } + d.routeIDLookup.AddResolvedIP(resourceId, prefix) } // RemoveResolvedIPLookupEntry removes a resolved IP lookup entry @@ -723,7 +721,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) { d.nsGroupStates = dnsStates } -func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) { +func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId route.ResID) { d.mux.Lock() defer d.mux.Unlock() diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 68d81d968..6d51c88c0 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -234,7 +234,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { origPattern = writer.GetOrigPattern() } - resolvedDomain := domain.Domain(r.Question[0].Name) + resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name)) // already punycode via RegisterHandler() originalDomain := domain.Domain(origPattern) @@ -328,6 +328,11 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom // Update domain prefixes using resolved domain as key if len(toAdd) > 0 || len(toRemove) > 0 { + if d.route.KeepRoute { + // replace stored prefixes with old + added + // nolint:gocritic + newPrefixes = append(oldPrefixes, toAdd...) + } d.interceptedDomains[resolvedDomain] = newPrefixes originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID()) @@ -338,7 +343,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom originalDomain.SafeString(), toAdd) } - if len(toRemove) > 0 { + if len(toRemove) > 0 && !d.route.KeepRoute { log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", resolvedDomain.SafeString(), originalDomain.SafeString(), diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index ae0d1d220..078206ab9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -259,8 +259,6 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } } - m.ctx = nil - m.mux.Lock() defer m.mux.Unlock() m.clientRoutes = nil @@ -292,7 +290,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro return nil } - if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil { + if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { return fmt.Errorf("update routes: %w", err) } diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index 48bb0380d..953210e9e 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -18,7 +18,7 @@ type serverRouter struct { func (r serverRouter) cleanUp() { } -func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error { +func (r serverRouter) updateRoutes(map[route.ID]*route.Route, bool) error { return nil } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 18713ee65..131d4c170 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -35,7 +35,10 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi }, nil } -func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { +func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error { + m.mux.Lock() + defer m.mux.Unlock() + serverRoutesToRemove := make([]route.ID, 0) for routeID := range m.routes { @@ -73,7 +76,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { continue } - err := m.addToServerNetwork(newRoute) + err := m.addToServerNetwork(newRoute, useNewDNSRoute) if err != nil { log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) continue @@ -90,57 +93,30 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { return m.ctx.Err() } - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.RemoveNatRule(routerPair) - if err != nil { + routerPair := routeToRouterPair(route, false) + if err := m.firewall.RemoveNatRule(routerPair); err != nil { return fmt.Errorf("remove routing rules: %w", err) } delete(m.routes, route.ID) - - routeStr := route.Network.String() - if route.IsDynamic() { - routeStr = route.Domains.SafeString() - } - m.statusRecorder.RemoveLocalPeerStateRoute(routeStr) + m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString()) return nil } -func (m *serverRouter) addToServerNetwork(route *route.Route) error { +func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error { if m.ctx.Err() != nil { log.Infof("Not adding to server network because context is done") return m.ctx.Err() } - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.AddNatRule(routerPair) - if err != nil { + routerPair := routeToRouterPair(route, useNewDNSRoute) + if err := m.firewall.AddNatRule(routerPair); err != nil { return fmt.Errorf("insert routing rules: %w", err) } m.routes[route.ID] = route - - routeStr := route.Network.String() - if route.IsDynamic() { - routeStr = route.Domains.SafeString() - } - - m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID()) + m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID()) return nil } @@ -148,31 +124,29 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error { func (m *serverRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() - for _, r := range m.routes { - routerPair, err := routeToRouterPair(r) - if err != nil { - log.Errorf("Failed to convert route to router pair: %v", err) - continue - } - err = m.firewall.RemoveNatRule(routerPair) - if err != nil { + for _, r := range m.routes { + routerPair := routeToRouterPair(r, false) + if err := m.firewall.RemoveNatRule(routerPair); err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } - } m.statusRecorder.CleanLocalPeerStateRoutes() } -func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { - // TODO: add ipv6 +func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair { source := getDefaultPrefix(route.Network) - - destination := route.Network.Masked() + destination := firewall.Network{} if route.IsDynamic() { - // TODO: add ipv6 additionally - destination = getDefaultPrefix(destination) + if useNewDNSRoute { + destination.Set = firewall.NewDomainSet(route.Domains) + } else { + // TODO: add ipv6 additionally + destination = getDefaultPrefix(destination.Prefix) + } + } else { + destination.Prefix = route.Network.Masked() } return firewall.RouterPair{ @@ -180,12 +154,16 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { Source: source, Destination: destination, Masquerade: route.Masquerade, - }, nil + } } -func getDefaultPrefix(prefix netip.Prefix) netip.Prefix { +func getDefaultPrefix(prefix netip.Prefix) firewall.Network { if prefix.Addr().Is6() { - return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + return firewall.Network{ + Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + } + } + return firewall.Network{ + Prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0), } - return netip.PrefixFrom(netip.IPv4Unspecified(), 0) } diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 5c117b94d..fd511fc20 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -1,6 +1,7 @@ package systemops import ( + "fmt" "net" "net/netip" "sync" @@ -15,6 +16,20 @@ type Nexthop struct { Intf *net.Interface } +// Equal checks if two nexthops are equal. +func (n Nexthop) Equal(other Nexthop) bool { + return n.IP == other.IP && (n.Intf == nil && other.Intf == nil || + n.Intf != nil && other.Intf != nil && n.Intf.Index == other.Intf.Index) +} + +// String returns a string representation of the nexthop. +func (n Nexthop) String() string { + if n.Intf == nil { + return n.IP.String() + } + return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name) +} + type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type SysOps struct { diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 84b84483e..a83d7f1de 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -24,7 +24,6 @@ func init() { testCases = append(testCases, []testCase{ { name: "To more specific route without custom dialer via vpn", - destination: "10.10.0.2:53", expectedInterface: expectedVPNint, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index cf3c2f0aa..59b6346c6 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -45,7 +45,7 @@ var sysctlFailed bool type ruleParams struct { priority int - fwmark int + fwmark uint32 tableID int family int invert bool @@ -55,8 +55,8 @@ type ruleParams struct { func getSetupRules() []ruleParams { return []ruleParams{ - {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, - {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, + {100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, + {100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"}, {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"}, } diff --git a/client/internal/routemanager/systemops/systemops_linux_test.go b/client/internal/routemanager/systemops/systemops_linux_test.go index 8f12740d0..f0d7472dc 100644 --- a/client/internal/routemanager/systemops/systemops_linux_test.go +++ b/client/internal/routemanager/systemops/systemops_linux_test.go @@ -27,14 +27,12 @@ func init() { testCases = append(testCases, []testCase{ { name: "To more specific route without custom dialer via physical interface", - destination: "10.10.0.2:53", expectedInterface: expectedInternalInt, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), }, { name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.1:53", expectedInterface: expectedLoopbackInt, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), @@ -134,6 +132,16 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { _, dstIPNet, err := net.ParseCIDR(dstCIDR) require.NoError(t, err) + link, err := netlink.LinkByName(intf) + require.NoError(t, err) + linkIndex := link.Attrs().Index + + route := &netlink.Route{ + Dst: dstIPNet, + Gw: gw, + LinkIndex: linkIndex, + } + // Handle existing routes with metric 0 var originalNexthop net.IP var originalLinkIndex int @@ -145,32 +153,24 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { } if originalNexthop != nil { + // remove original route err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - switch { - case err != nil && !errors.Is(err, syscall.ESRCH): - t.Logf("Failed to delete route: %v", err) - case err == nil: - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - default: - t.Logf("Failed to delete route: %v", err) - } + assert.NoError(t, err) + + // add new route + assert.NoError(t, netlink.RouteAdd(route)) + + t.Cleanup(func() { + // restore original route + assert.NoError(t, netlink.RouteDel(route)) + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + assert.NoError(t, err) + }) + + return } } - link, err := netlink.LinkByName(intf) - require.NoError(t, err) - linkIndex := link.Attrs().Index - - route := &netlink.Route{ - Dst: dstIPNet, - Gw: gw, - LinkIndex: linkIndex, - } err = netlink.RouteDel(route) if err != nil && !errors.Is(err, syscall.ESRCH) { t.Logf("Failed to delete route: %v", err) @@ -180,7 +180,6 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { if err != nil && !errors.Is(err, syscall.EEXIST) { t.Fatalf("Failed to add route: %v", err) } - require.NoError(t, err) } func fetchOriginalGateway(family int) (net.IP, int, error) { @@ -190,7 +189,11 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } for _, route := range routes { - if route.Dst == nil && route.Priority == 0 { + ones := -1 + if route.Dst != nil { + ones, _ = route.Dst.Mask.Size() + } + if route.Dst == nil || ones == 0 && route.Priority == 0 { return route.Gw, route.LinkIndex, nil } } diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index d88c1ab6b..ad37f611f 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -31,7 +31,6 @@ type PacketExpectation struct { type testCase struct { name string - destination string expectedInterface string dialer dialer expectedPacket PacketExpectation @@ -40,14 +39,12 @@ type testCase struct { var testCases = []testCase{ { name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", expectedInterface: expectedVPNint, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), }, { name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", expectedInterface: expectedExternalInt, dialer: nbnet.NewDialer(), expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), @@ -55,14 +52,12 @@ var testCases = []testCase{ { name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", expectedInterface: expectedInternalInt, dialer: nbnet.NewDialer(), expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), }, { name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", expectedInterface: expectedInternalInt, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), @@ -70,14 +65,12 @@ var testCases = []testCase{ { name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", expectedInterface: expectedExternalInt, dialer: nbnet.NewDialer(), expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), }, { name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", expectedInterface: expectedVPNint, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), @@ -94,10 +87,11 @@ func TestRouting(t *testing.T) { t.Run(tc.name, func(t *testing.T) { setupTestEnv(t) - filter := createBPFFilter(tc.destination) + dst := fmt.Sprintf("%s:%d", tc.expectedPacket.DstIP, tc.expectedPacket.DstPort) + filter := createBPFFilter(dst) handle := startPacketCapture(t, tc.expectedInterface, filter) - sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) + sendTestPacket(t, dst, tc.expectedPacket.SrcPort, tc.dialer) packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) packet, err := packetSource.NextPacket() diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index ad325e123..f66161595 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -33,8 +33,7 @@ type RouteUpdateType int type RouteUpdate struct { Type RouteUpdateType Destination netip.Prefix - NextHop netip.Addr - Interface *net.Interface + NextHop Nexthop } // RouteMonitor provides a way to monitor changes in the routing table. @@ -231,15 +230,15 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI intf, err := net.InterfaceByIndex(idx) if err != nil { log.Warnf("failed to get interface name for index %d: %v", idx, err) - update.Interface = &net.Interface{ + update.NextHop.Intf = &net.Interface{ Index: idx, } } else { - update.Interface = intf + update.NextHop.Intf = intf } } - log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface) + log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.NextHop.Intf) dest := parseIPPrefix(row.DestinationPrefix, idx) if !dest.Addr().IsValid() { return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row) @@ -262,7 +261,7 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI update.Type = updateType update.Destination = dest - update.NextHop = nexthop + update.NextHop.IP = nexthop return update, nil } diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 72c4758f4..8ebdc63e5 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -14,23 +14,15 @@ import ( ) type RouteSelector struct { - mu sync.RWMutex - selectedRoutes map[route.NetID]struct{} - selectAll bool - - // Indicates if new routes should be automatically selected - includeNewRoutes bool - - // All known routes at the time of deselection - knownRoutes []route.NetID + mu sync.RWMutex + deselectedRoutes map[route.NetID]struct{} + deselectAll bool } func NewRouteSelector() *RouteSelector { return &RouteSelector{ - selectedRoutes: map[route.NetID]struct{}{}, - selectAll: true, - includeNewRoutes: false, - knownRoutes: []route.NetID{}, + deselectedRoutes: map[route.NetID]struct{}{}, + deselectAll: false, } } @@ -39,8 +31,11 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al rs.mu.Lock() defer rs.mu.Unlock() - if !appendRoute { - rs.selectedRoutes = map[route.NetID]struct{}{} + if !appendRoute || rs.deselectAll { + maps.Clear(rs.deselectedRoutes) + for _, r := range allRoutes { + rs.deselectedRoutes[r] = struct{}{} + } } var err *multierror.Error @@ -49,11 +44,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) continue } - - rs.selectedRoutes[route] = struct{}{} + delete(rs.deselectedRoutes, route) } - rs.selectAll = false - rs.includeNewRoutes = false + + rs.deselectAll = false return errors.FormatErrorOrNil(err) } @@ -63,38 +57,26 @@ func (rs *RouteSelector) SelectAllRoutes() { rs.mu.Lock() defer rs.mu.Unlock() - rs.selectAll = true - rs.selectedRoutes = map[route.NetID]struct{}{} - rs.includeNewRoutes = false + rs.deselectAll = false + maps.Clear(rs.deselectedRoutes) } // DeselectRoutes removes specific routes from the selection. -// If the selector is in "select all" mode, it will transition to "select specific" mode -// but will keep new routes selected. func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { rs.mu.Lock() defer rs.mu.Unlock() - if rs.selectAll { - rs.selectAll = false - rs.includeNewRoutes = true - rs.knownRoutes = make([]route.NetID, len(allRoutes)) - copy(rs.knownRoutes, allRoutes) - - rs.selectedRoutes = map[route.NetID]struct{}{} - for _, route := range allRoutes { - rs.selectedRoutes[route] = struct{}{} - } + if rs.deselectAll { + return nil } var err *multierror.Error - for _, route := range routes { if !slices.Contains(allRoutes, route) { err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) continue } - delete(rs.selectedRoutes, route) + rs.deselectedRoutes[route] = struct{}{} } return errors.FormatErrorOrNil(err) @@ -105,9 +87,8 @@ func (rs *RouteSelector) DeselectAllRoutes() { rs.mu.Lock() defer rs.mu.Unlock() - rs.selectAll = false - rs.includeNewRoutes = false - rs.selectedRoutes = map[route.NetID]struct{}{} + rs.deselectAll = true + maps.Clear(rs.deselectedRoutes) } // IsSelected checks if a specific route is selected. @@ -115,23 +96,12 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { rs.mu.RLock() defer rs.mu.RUnlock() - if rs.selectAll { - return true + if rs.deselectAll { + return false } - // Check if the route exists in selectedRoutes - _, selected := rs.selectedRoutes[routeID] - if selected { - return true - } - - // If includeNewRoutes is true and this is a new route (not in knownRoutes), - // then it should be selected - if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) { - return true - } - - return false + _, deselected := rs.deselectedRoutes[routeID] + return !deselected } // FilterSelected removes unselected routes from the provided map. @@ -139,17 +109,15 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { rs.mu.RLock() defer rs.mu.RUnlock() - if rs.selectAll { - return maps.Clone(routes) + if rs.deselectAll { + return route.HAMap{} } filtered := route.HAMap{} for id, rt := range routes { netID := id.NetID() - _, selected := rs.selectedRoutes[netID] - - // Include if directly selected or if it's a new route and includeNewRoutes is true - if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) { + _, deselected := rs.deselectedRoutes[netID] + if !deselected { filtered[id] = rt } } @@ -162,15 +130,11 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) { defer rs.mu.RUnlock() return json.Marshal(struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` - IncludeNewRoutes bool `json:"include_new_routes"` - KnownRoutes []route.NetID `json:"known_routes"` + DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` + DeselectAll bool `json:"deselect_all"` }{ - SelectAll: rs.selectAll, - SelectedRoutes: rs.selectedRoutes, - IncludeNewRoutes: rs.includeNewRoutes, - KnownRoutes: rs.knownRoutes, + DeselectedRoutes: rs.deselectedRoutes, + DeselectAll: rs.deselectAll, }) } @@ -182,34 +146,25 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { // Check for null or empty JSON if len(data) == 0 || string(data) == "null" { - rs.selectedRoutes = map[route.NetID]struct{}{} - rs.selectAll = true - rs.includeNewRoutes = false - rs.knownRoutes = []route.NetID{} + rs.deselectedRoutes = map[route.NetID]struct{}{} + rs.deselectAll = false return nil } var temp struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` - IncludeNewRoutes bool `json:"include_new_routes"` - KnownRoutes []route.NetID `json:"known_routes"` + DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` + DeselectAll bool `json:"deselect_all"` } if err := json.Unmarshal(data, &temp); err != nil { return err } - rs.selectedRoutes = temp.SelectedRoutes - rs.selectAll = temp.SelectAll - rs.includeNewRoutes = temp.IncludeNewRoutes - rs.knownRoutes = temp.KnownRoutes + rs.deselectedRoutes = temp.DeselectedRoutes + rs.deselectAll = temp.DeselectAll - if rs.selectedRoutes == nil { - rs.selectedRoutes = map[route.NetID]struct{}{} - } - if rs.knownRoutes == nil { - rs.knownRoutes = []route.NetID{} + if rs.deselectedRoutes == nil { + rs.deselectedRoutes = map[route.NetID]struct{}{} } return nil diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index a1461dff6..cfa723246 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -66,12 +66,10 @@ func TestRouteSelector_SelectRoutes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { rs := routeselector.NewRouteSelector() - if tt.initialSelected != nil { - err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) - require.NoError(t, err) - } + err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) + require.NoError(t, err) - err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes) + err = rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes) if tt.wantError { assert.Error(t, err) } else { @@ -251,7 +249,8 @@ func TestRouteSelector_IsSelected(t *testing.T) { assert.True(t, rs.IsSelected("route1")) assert.True(t, rs.IsSelected("route2")) assert.False(t, rs.IsSelected("route3")) - assert.False(t, rs.IsSelected("route4")) + // Unknown route is selected by default + assert.True(t, rs.IsSelected("route4")) } func TestRouteSelector_FilterSelected(t *testing.T) { @@ -297,8 +296,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialState: func(rs *routeselector.RouteSelector) error { return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes) }, - // When specific routes were selected, new routes should remain unselected - wantNewSelected: []route.NetID{"route1", "route2"}, + // When specific routes were selected, new routes should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route4", "route5"}, }, { name: "New routes after deselect all", @@ -315,7 +314,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { rs.SelectAllRoutes() return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) }, - // After deselecting specific routes, new routes should remain unselected + // After deselecting specific routes, new routes should be selected wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"}, }, { @@ -323,8 +322,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialState: func(rs *routeselector.RouteSelector) error { return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes) }, - // When routes were appended, new routes should remain unselected - wantNewSelected: []route.NetID{"route1"}, + // When routes were appended, new routes should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"}, }, } @@ -428,3 +427,213 @@ func TestRouteSelector_MixedSelectionDeselection(t *testing.T) { }) } } + +func TestRouteSelector_AfterDeselectAll(t *testing.T) { + allRoutes := []route.NetID{"route1", "route2", "route3"} + + tests := []struct { + name string + initialAction func(rs *routeselector.RouteSelector) error + secondAction func(rs *routeselector.RouteSelector) error + wantSelected []route.NetID + wantError bool + }{ + { + name: "Deselect all -> select specific routes", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes) + }, + wantSelected: []route.NetID{"route1", "route2"}, + }, + { + name: "Deselect all -> select with append", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes) + }, + wantSelected: []route.NetID{"route1"}, + }, + { + name: "Deselect all -> deselect specific", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route1"}, allRoutes) + }, + wantSelected: []route.NetID{}, + }, + { + name: "Deselect all -> select all", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + wantSelected: []route.NetID{"route1", "route2", "route3"}, + }, + { + name: "Deselect all -> deselect non-existent route", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route4"}, allRoutes) + }, + wantSelected: []route.NetID{}, + wantError: false, + }, + { + name: "Select specific -> deselect all -> select different", + initialAction: func(rs *routeselector.RouteSelector) error { + err := rs.SelectRoutes([]route.NetID{"route1"}, false, allRoutes) + if err != nil { + return err + } + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route2", "route3"}, false, allRoutes) + }, + wantSelected: []route.NetID{"route2", "route3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + err := tt.initialAction(rs) + require.NoError(t, err) + + err = tt.secondAction(rs) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + for _, id := range allRoutes { + expected := slices.Contains(tt.wantSelected, id) + assert.Equal(t, expected, rs.IsSelected(id), + "Route %s selection state incorrect, expected %v", id, expected) + } + + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + } + + filtered := rs.FilterSelected(routes) + assert.Equal(t, len(tt.wantSelected), len(filtered), + "FilterSelected returned wrong number of routes") + }) + } +} + +func TestRouteSelector_ComplexScenarios(t *testing.T) { + allRoutes := []route.NetID{"route1", "route2", "route3", "route4"} + + tests := []struct { + name string + actions []func(rs *routeselector.RouteSelector) error + wantSelected []route.NetID + }{ + { + name: "Select all -> deselect specific -> select different with append", + actions: []func(rs *routeselector.RouteSelector) error{ + func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route1", "route2"}, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes) + }, + }, + wantSelected: []route.NetID{"route1", "route3", "route4"}, + }, + { + name: "Deselect all -> select specific -> deselect one -> select different with append", + actions: []func(rs *routeselector.RouteSelector) error{ + func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route3"}, true, allRoutes) + }, + }, + wantSelected: []route.NetID{"route1", "route3"}, + }, + { + name: "Select specific -> deselect specific -> select all -> deselect different", + actions: []func(rs *routeselector.RouteSelector) error{ + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route3", "route4"}, allRoutes) + }, + }, + wantSelected: []route.NetID{"route1", "route2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + for i, action := range tt.actions { + err := action(rs) + require.NoError(t, err, "Action %d failed", i) + } + + for _, id := range allRoutes { + expected := slices.Contains(tt.wantSelected, id) + assert.Equal(t, expected, rs.IsSelected(id), + "Route %s selection state incorrect", id) + } + + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + "route4|10.10.0.0/16": {}, + } + + filtered := rs.FilterSelected(routes) + assert.Equal(t, len(tt.wantSelected), len(filtered), + "FilterSelected returned wrong number of routes") + }) + } +} diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index d04d7a9c0..879fb8032 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.24.3 +// protoc v3.21.9 // source: daemon.proto package proto @@ -2277,6 +2277,7 @@ type DebugBundleRequest struct { Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"` Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"` + UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"` } func (x *DebugBundleRequest) Reset() { @@ -2332,12 +2333,21 @@ func (x *DebugBundleRequest) GetSystemInfo() bool { return false } +func (x *DebugBundleRequest) GetUploadURL() string { + if x != nil { + return x.UploadURL + } + return "" +} + type DebugBundleResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + UploadedKey string `protobuf:"bytes,2,opt,name=uploadedKey,proto3" json:"uploadedKey,omitempty"` + UploadFailureReason string `protobuf:"bytes,3,opt,name=uploadFailureReason,proto3" json:"uploadFailureReason,omitempty"` } func (x *DebugBundleResponse) Reset() { @@ -2379,6 +2389,20 @@ func (x *DebugBundleResponse) GetPath() string { return "" } +func (x *DebugBundleResponse) GetUploadedKey() string { + if x != nil { + return x.UploadedKey + } + return "" +} + +func (x *DebugBundleResponse) GetUploadFailureReason() string { + if x != nil { + return x.UploadFailureReason + } + return "" +} + type GetLogLevelRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3924,244 +3948,251 @@ var file_daemon_proto_rawDesc = []byte{ 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2c, 0x0a, 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x05, 0x72, 0x75, 0x6c, - 0x65, 0x73, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, - 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, - 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, - 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, - 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, - 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, - 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x3c, - 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, - 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, - 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, - 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, - 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, - 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, - 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, - 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, - 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, - 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, - 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, - 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, - 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, - 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, - 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, - 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x76, 0x0a, 0x08, 0x54, - 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x73, 0x79, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x12, 0x10, 0x0a, 0x03, 0x66, - 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x10, 0x0a, - 0x03, 0x72, 0x73, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x72, 0x73, 0x74, 0x12, - 0x10, 0x0a, 0x03, 0x70, 0x73, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x73, - 0x68, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, - 0x75, 0x72, 0x67, 0x22, 0x80, 0x03, 0x0a, 0x12, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, - 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, - 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x69, - 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0d, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x70, 0x12, 0x1a, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, - 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, - 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x29, 0x0a, 0x10, 0x64, - 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x09, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, - 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x48, 0x00, 0x52, 0x08, 0x74, 0x63, 0x70, - 0x46, 0x6c, 0x61, 0x67, 0x73, 0x88, 0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, - 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x01, 0x52, 0x08, 0x69, - 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x88, 0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, - 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x02, 0x52, - 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x88, 0x01, 0x01, 0x42, 0x0c, 0x0a, 0x0a, - 0x5f, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, - 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, - 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x9f, 0x01, 0x0a, 0x0a, 0x54, 0x72, 0x61, 0x63, 0x65, - 0x53, 0x74, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x32, 0x0a, - 0x12, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, - 0x69, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x11, 0x66, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x88, 0x01, - 0x01, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, - 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x22, 0x6e, 0x0a, 0x13, 0x54, 0x72, 0x61, 0x63, - 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x2a, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, - 0x61, 0x67, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x66, - 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x44, 0x69, 0x73, - 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x12, 0x0a, 0x10, 0x53, 0x75, 0x62, 0x73, - 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x93, 0x04, 0x0a, - 0x0b, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x0e, 0x0a, 0x02, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x38, 0x0a, 0x08, - 0x73, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, - 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x52, 0x08, 0x73, 0x65, - 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x12, 0x38, 0x0a, 0x08, 0x63, 0x61, 0x74, 0x65, 0x67, 0x6f, - 0x72, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x61, - 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x52, 0x08, 0x63, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, - 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x75, 0x73, - 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0b, 0x75, 0x73, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x38, 0x0a, 0x09, - 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, - 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x3d, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x65, - 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, - 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, - 0x38, 0x01, 0x22, 0x3a, 0x0a, 0x08, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x12, 0x08, - 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x57, 0x41, 0x52, 0x4e, - 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x02, - 0x12, 0x0c, 0x0a, 0x08, 0x43, 0x52, 0x49, 0x54, 0x49, 0x43, 0x41, 0x4c, 0x10, 0x03, 0x22, 0x52, - 0x0a, 0x08, 0x43, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x12, 0x0b, 0x0a, 0x07, 0x4e, 0x45, - 0x54, 0x57, 0x4f, 0x52, 0x4b, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x4e, 0x53, 0x10, 0x01, - 0x12, 0x12, 0x0a, 0x0e, 0x41, 0x55, 0x54, 0x48, 0x45, 0x4e, 0x54, 0x49, 0x43, 0x41, 0x54, 0x49, - 0x4f, 0x4e, 0x10, 0x02, 0x12, 0x10, 0x0a, 0x0c, 0x43, 0x4f, 0x4e, 0x4e, 0x45, 0x43, 0x54, 0x49, - 0x56, 0x49, 0x54, 0x59, 0x10, 0x03, 0x12, 0x0a, 0x0a, 0x06, 0x53, 0x59, 0x53, 0x54, 0x45, 0x4d, - 0x10, 0x04, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x40, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, - 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2b, 0x0a, 0x06, 0x65, - 0x76, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, - 0x52, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, - 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, - 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, - 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, - 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xb3, 0x0b, 0x0a, - 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, - 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, - 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, - 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, - 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, - 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, - 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0f, 0x46, - 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x14, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x6f, - 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, - 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, - 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, - 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, - 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, - 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, - 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, + 0x65, 0x73, 0x22, 0x88, 0x01, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, + 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, + 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, + 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, + 0x1c, 0x0a, 0x09, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x55, 0x52, 0x4c, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x55, 0x52, 0x4c, 0x22, 0x7d, 0x0a, + 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x75, 0x70, 0x6c, 0x6f, + 0x61, 0x64, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x75, + 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x75, 0x70, + 0x6c, 0x6f, 0x61, 0x64, 0x46, 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x52, 0x65, 0x61, 0x73, 0x6f, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x46, + 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x22, 0x14, 0x0a, 0x12, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, + 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, + 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, + 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, + 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, + 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, + 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, - 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, - 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, - 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, - 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, - 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, - 0x0a, 0x0b, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1a, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, - 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x44, 0x0a, 0x0f, 0x53, 0x75, 0x62, 0x73, - 0x63, 0x72, 0x69, 0x62, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, - 0x0a, 0x09, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, - 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, + 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, + 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, + 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, + 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, + 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, + 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, + 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, + 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, + 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, + 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x76, + 0x0a, 0x08, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x79, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x73, 0x79, 0x6e, 0x12, 0x10, 0x0a, 0x03, + 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x12, 0x10, + 0x0a, 0x03, 0x66, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x66, 0x69, 0x6e, + 0x12, 0x10, 0x0a, 0x03, 0x72, 0x73, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x72, + 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x73, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x03, 0x70, 0x73, 0x68, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x03, 0x75, 0x72, 0x67, 0x22, 0x80, 0x03, 0x0a, 0x12, 0x54, 0x72, 0x61, 0x63, 0x65, + 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, + 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, + 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0d, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, + 0x70, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1f, 0x0a, + 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x29, + 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x6f, + 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x09, 0x74, 0x63, 0x70, 0x5f, 0x66, + 0x6c, 0x61, 0x67, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x48, 0x00, 0x52, 0x08, + 0x74, 0x63, 0x70, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x88, 0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, + 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x01, + 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x88, 0x01, 0x01, 0x12, 0x20, 0x0a, + 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0d, + 0x48, 0x02, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x88, 0x01, 0x01, 0x42, + 0x0c, 0x0a, 0x0a, 0x5f, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x42, 0x0c, 0x0a, + 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, + 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x9f, 0x01, 0x0a, 0x0a, 0x54, 0x72, + 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, + 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, + 0x12, 0x32, 0x0a, 0x12, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x5f, 0x64, + 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x11, + 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, + 0x73, 0x88, 0x01, 0x01, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, + 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x22, 0x6e, 0x0a, 0x13, 0x54, + 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x2a, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, + 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, 0x73, 0x12, 0x2b, + 0x0a, 0x11, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, + 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6e, 0x61, 0x6c, + 0x44, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x12, 0x0a, 0x10, 0x53, + 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, + 0x93, 0x04, 0x0a, 0x0b, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, + 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, + 0x38, 0x0a, 0x08, 0x73, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, + 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x52, + 0x08, 0x73, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x12, 0x38, 0x0a, 0x08, 0x63, 0x61, 0x74, + 0x65, 0x67, 0x6f, 0x72, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, + 0x2e, 0x43, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x52, 0x08, 0x63, 0x61, 0x74, 0x65, 0x67, + 0x6f, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x20, 0x0a, + 0x0b, 0x75, 0x73, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0b, 0x75, 0x73, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, + 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, + 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x3d, 0x0a, 0x08, 0x6d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, + 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, + 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, + 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x3a, 0x0a, 0x08, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, + 0x79, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x57, + 0x41, 0x52, 0x4e, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, + 0x52, 0x10, 0x02, 0x12, 0x0c, 0x0a, 0x08, 0x43, 0x52, 0x49, 0x54, 0x49, 0x43, 0x41, 0x4c, 0x10, + 0x03, 0x22, 0x52, 0x0a, 0x08, 0x43, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x12, 0x0b, 0x0a, + 0x07, 0x4e, 0x45, 0x54, 0x57, 0x4f, 0x52, 0x4b, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x4e, + 0x53, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x41, 0x55, 0x54, 0x48, 0x45, 0x4e, 0x54, 0x49, 0x43, + 0x41, 0x54, 0x49, 0x4f, 0x4e, 0x10, 0x02, 0x12, 0x10, 0x0a, 0x0c, 0x43, 0x4f, 0x4e, 0x4e, 0x45, + 0x43, 0x54, 0x49, 0x56, 0x49, 0x54, 0x59, 0x10, 0x03, 0x12, 0x0a, 0x0a, 0x06, 0x53, 0x59, 0x53, + 0x54, 0x45, 0x4d, 0x10, 0x04, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, + 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x40, 0x0a, 0x11, 0x47, 0x65, 0x74, + 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2b, + 0x0a, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, + 0x65, 0x6e, 0x74, 0x52, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2a, 0x62, 0x0a, 0x08, 0x4c, + 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, + 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, + 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, + 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, + 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, + 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, + 0xb3, 0x0b, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, + 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4a, + 0x0a, 0x0f, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, + 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, + 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, + 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, + 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, + 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, + 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, + 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, + 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, + 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, + 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, + 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, + 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, + 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, + 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x44, 0x0a, 0x0f, 0x53, + 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, + 0x01, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 49e577853..6c63a8f9b 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -336,10 +336,13 @@ message DebugBundleRequest { bool anonymize = 1; string status = 2; bool systemInfo = 3; + string uploadURL = 4; } message DebugBundleResponse { string path = 1; + string uploadedKey = 2; + string uploadFailureReason = 3; } enum LogLevel { diff --git a/client/server/debug.go b/client/server/debug.go index 9ccfb13fb..7de3e8609 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -4,16 +4,24 @@ package server import ( "context" + "crypto/sha256" + "encoding/json" "errors" "fmt" + "io" + "net/http" + "os" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/proto" mgmProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/upload-server/types" ) +const maxBundleUploadSize = 50 * 1024 * 1024 + // DebugBundle creates a debug bundle and returns the location. func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { s.mutex.Lock() @@ -42,7 +50,104 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( return nil, fmt.Errorf("generate debug bundle: %w", err) } - return &proto.DebugBundleResponse{Path: path}, nil + if req.GetUploadURL() == "" { + return &proto.DebugBundleResponse{Path: path}, nil + } + key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path) + if err != nil { + log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err) + return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil + } + + log.Infof("debug bundle uploaded to %s with key %s", req.GetUploadURL(), key) + + return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil +} + +func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) { + response, err := getUploadURL(ctx, url, managementURL) + if err != nil { + return "", err + } + + err = upload(ctx, filePath, response) + if err != nil { + return "", err + } + return response.Key, nil +} + +func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error { + fileData, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("open file: %w", err) + } + + defer fileData.Close() + + stat, err := fileData.Stat() + if err != nil { + return fmt.Errorf("stat file: %w", err) + } + + if stat.Size() > maxBundleUploadSize { + return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize) + } + + req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData) + if err != nil { + return fmt.Errorf("create PUT request: %w", err) + } + + req.ContentLength = stat.Size() + req.Header.Set("Content-Type", "application/octet-stream") + + putResp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("upload failed: %v", err) + } + defer putResp.Body.Close() + + if putResp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(putResp.Body) + return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body)) + } + return nil +} + +func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) { + id := getURLHash(managementURL) + getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil) + if err != nil { + return nil, fmt.Errorf("create GET request: %w", err) + } + + getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue) + + resp, err := http.DefaultClient.Do(getReq) + if err != nil { + return nil, fmt.Errorf("get presigned URL: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body)) + } + + urlBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + var response types.GetURLResponse + if err := json.Unmarshal(urlBytes, &response); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + return &response, nil +} + +func getURLHash(url string) string { + return fmt.Sprintf("%x", sha256.Sum256([]byte(url))) } // GetLogLevel gets the current logging level for the server. diff --git a/client/server/debug_test.go b/client/server/debug_test.go new file mode 100644 index 000000000..53d9ac8ed --- /dev/null +++ b/client/server/debug_test.go @@ -0,0 +1,49 @@ +package server + +import ( + "context" + "errors" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/upload-server/server" + "github.com/netbirdio/netbird/upload-server/types" +) + +func TestUpload(t *testing.T) { + if os.Getenv("DOCKER_CI") == "true" { + t.Skip("Skipping upload test on docker ci") + } + testDir := t.TempDir() + testURL := "http://localhost:8080" + t.Setenv("SERVER_URL", testURL) + t.Setenv("STORE_DIR", testDir) + srv := server.NewServer() + go func() { + if err := srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("Failed to start server: %v", err) + } + }() + t.Cleanup(func() { + if err := srv.Stop(); err != nil { + t.Errorf("Failed to stop server: %v", err) + } + }) + + file := filepath.Join(t.TempDir(), "tmpfile") + fileContent := []byte("test file content") + err := os.WriteFile(file, fileContent, 0640) + require.NoError(t, err) + key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file) + require.NoError(t, err) + id := getURLHash(testURL) + require.Contains(t, key, id+"/") + expectedFilePath := filepath.Join(testDir, key) + createdFileContent, err := os.ReadFile(expectedFilePath) + require.NoError(t, err) + require.Equal(t, fileContent, createdFileContent) +} diff --git a/client/server/network.go b/client/server/network.go index e0b01f763..93b7caa46 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -100,7 +100,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro // Convert to proto format for domain, ips := range domainMap { - pbRoute.ResolvedIPs[domain.PunycodeString()] = &proto.IPList{ + pbRoute.ResolvedIPs[domain.SafeString()] = &proto.IPList{ Ips: ips, } } diff --git a/client/status/status.go b/client/status/status.go index 43acc9197..f37e5b0f0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/version" ) @@ -414,7 +415,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, signalConnString, relaysString, dnsServersString, - overview.FQDN, + domain.Domain(overview.FQDN).SafeString(), interfaceIP, interfaceTypeString, rosenpassEnabledStatus, @@ -508,7 +509,7 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Quantum resistance: %s\n"+ " Networks: %s\n"+ " Latency: %s\n", - peerState.FQDN, + domain.Domain(peerState.FQDN).SafeString(), peerState.IP, peerState.PubKey, peerState.Status, diff --git a/client/system/info.go b/client/system/info.go index 2a0343ca6..3a0c57156 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -185,3 +185,10 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro return info, nil } + +// UpdateStaticInfo asynchronously updates static system and platform information +func UpdateStaticInfo() { + go func() { + _ = updateStaticInfo() + }() +} diff --git a/client/system/static_info.go b/client/system/static_info.go index fabe65a68..f178ec932 100644 --- a/client/system/static_info.go +++ b/client/system/static_info.go @@ -16,12 +16,6 @@ var ( once sync.Once ) -func init() { - go func() { - _ = updateStaticInfo() - }() -} - func updateStaticInfo() StaticInfo { once.Do(func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) diff --git a/client/system/static_info_stub.go b/client/system/static_info_stub.go new file mode 100644 index 000000000..faa3e700b --- /dev/null +++ b/client/system/static_info_stub.go @@ -0,0 +1,8 @@ +//go:build android || freebsd || ios + +package system + +// updateStaticInfo returns an empty implementation for unsupported platforms +func updateStaticInfo() StaticInfo { + return StaticInfo{} +} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index b2a6404bb..2c8023185 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -51,14 +51,17 @@ const ( ) func main() { - daemonAddr, showSettings, showNetworks, errorMsg, saveLogsInFile := parseFlags() + daemonAddr, showSettings, showNetworks, showDebug, errorMsg, saveLogsInFile := parseFlags() // Initialize file logging if needed. + var logFile string if saveLogsInFile { - if err := initLogFile(); err != nil { + file, err := initLogFile() + if err != nil { log.Errorf("error while initializing log: %v", err) return } + logFile = file } // Create the Fyne application. @@ -72,13 +75,13 @@ func main() { } // Create the service client (this also builds the settings or networks UI if requested). - client := newServiceClient(daemonAddr, a, showSettings, showNetworks) + client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showDebug) // Watch for theme/settings changes to update the icon. go watchSettingsChanges(a, client) // Run in window mode if any UI flag was set. - if showSettings || showNetworks { + if showSettings || showNetworks || showDebug { a.Run() return } @@ -99,7 +102,7 @@ func main() { } // parseFlags reads and returns all needed command-line flags. -func parseFlags() (daemonAddr string, showSettings, showNetworks bool, errorMsg string, saveLogsInFile bool) { +func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool, errorMsg string, saveLogsInFile bool) { defaultDaemonAddr := "unix:///var/run/netbird.sock" if runtime.GOOS == "windows" { defaultDaemonAddr = "tcp://127.0.0.1:41731" @@ -107,25 +110,17 @@ func parseFlags() (daemonAddr string, showSettings, showNetworks bool, errorMsg flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") flag.BoolVar(&showSettings, "settings", false, "run settings window") flag.BoolVar(&showNetworks, "networks", false, "run networks window") + flag.BoolVar(&showDebug, "debug", false, "run debug window") flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window") - - tmpDir := "/tmp" - if runtime.GOOS == "windows" { - tmpDir = os.TempDir() - } - flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir)) + flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) flag.Parse() return } // initLogFile initializes logging into a file. -func initLogFile() error { - tmpDir := "/tmp" - if runtime.GOOS == "windows" { - tmpDir = os.TempDir() - } - logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid())) - return util.InitLog("trace", logFile) +func initLogFile() (string, error) { + logFile := path.Join(os.TempDir(), fmt.Sprintf("netbird-ui-%d.log", os.Getpid())) + return logFile, util.InitLog("trace", logFile) } // watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon. @@ -168,9 +163,10 @@ var iconConnectingMacOS []byte var iconErrorMacOS []byte type serviceClient struct { - ctx context.Context - addr string - conn proto.DaemonServiceClient + ctx context.Context + cancel context.CancelFunc + addr string + conn proto.DaemonServiceClient icAbout []byte icConnected []byte @@ -231,13 +227,14 @@ type serviceClient struct { daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool - showRoutes bool - wRoutes fyne.Window + showNetworks bool + wNetworks fyne.Window eventManager *event.Manager exitNodeMu sync.Mutex mExitNodeItems []menuHandler + logFile string } type menuHandler struct { @@ -248,25 +245,30 @@ type menuHandler struct { // newServiceClient instance constructor // // This constructor also builds the UI elements for the settings window. -func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient { +func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient { + ctx, cancel := context.WithCancel(context.Background()) s := &serviceClient{ - ctx: context.Background(), + ctx: ctx, + cancel: cancel, addr: addr, app: a, + logFile: logFile, sendNotification: false, showAdvancedSettings: showSettings, - showRoutes: showRoutes, + showNetworks: showNetworks, update: version.NewUpdate(), } s.setNewIcons() - if showSettings { + switch { + case showSettings: s.showSettingsUI() - return s - } else if showRoutes { + case showNetworks: s.showNetworksUI() + case showDebug: + s.showDebugUI() } return s @@ -313,6 +315,8 @@ func (s *serviceClient) updateIcon() { func (s *serviceClient) showSettingsUI() { // add settings window UI elements. s.wSettings = s.app.NewWindow("NetBird Settings") + s.wSettings.SetOnClosed(s.cancel) + s.iMngURL = widget.NewEntry() s.iAdminURL = widget.NewEntry() s.iConfigFile = widget.NewEntry() @@ -457,7 +461,7 @@ func (s *serviceClient) menuUpClick() error { if status.Status == string(internal.StatusConnected) { log.Warnf("already connected") - return err + return nil } if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { @@ -482,7 +486,7 @@ func (s *serviceClient) menuDownClick() error { return err } - if status.Status != string(internal.StatusConnected) { + if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { log.Warnf("already down") return nil } @@ -520,7 +524,9 @@ func (s *serviceClient) updateStatus() error { } var systrayIconState bool - if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() { + + switch { + case status.Status == string(internal.StatusConnected): s.connected = true s.sendNotification = true if s.isUpdateIconActive { @@ -535,7 +541,9 @@ func (s *serviceClient) updateStatus() error { s.mNetworks.Enable() go s.updateExitNodes() systrayIconState = true - } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { + case status.Status == string(internal.StatusConnecting): + s.setConnectingStatus() + case status.Status != string(internal.StatusConnected) && s.mUp.Disabled(): s.setDisconnectedStatus() systrayIconState = false } @@ -594,6 +602,17 @@ func (s *serviceClient) setDisconnectedStatus() { go s.updateExitNodes() } +func (s *serviceClient) setConnectingStatus() { + s.connected = false + systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) + systray.SetTooltip("NetBird (Connecting)") + s.mStatus.SetTitle("Connecting") + s.mUp.Disable() + s.mDown.Enable() + s.mNetworks.Disable() + s.mExitNode.Disable() +} + func (s *serviceClient) onTrayReady() { systray.SetTemplateIcon(iconDisconnectedMacOS, s.icDisconnected) systray.SetTooltip("NetBird") @@ -728,11 +747,10 @@ func (s *serviceClient) onTrayReady() { s.runSelfCommand("settings", "true") }() case <-s.mCreateDebugBundle.ClickedCh: + s.mCreateDebugBundle.Disable() go func() { - if err := s.createAndOpenDebugBundle(); err != nil { - log.Errorf("Failed to create debug bundle: %v", err) - s.app.SendNotification(fyne.NewNotification("Error", "Failed to create debug bundle")) - } + defer s.mCreateDebugBundle.Enable() + s.runSelfCommand("debug", "true") }() case <-s.mQuit.ClickedCh: systray.Quit() @@ -774,7 +792,7 @@ func (s *serviceClient) onTrayReady() { func (s *serviceClient) runSelfCommand(command, arg string) { proc, err := os.Executable() if err != nil { - log.Errorf("show %s failed with error: %v", command, err) + log.Errorf("Error getting executable path: %v", err) return } @@ -783,14 +801,48 @@ func (s *serviceClient) runSelfCommand(command, arg string) { fmt.Sprintf("--daemon-addr=%s", s.addr), ) - out, err := cmd.CombinedOutput() - if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 { - log.Errorf("start %s UI: %v, %s", command, err, string(out)) + if out := s.attachOutput(cmd); out != nil { + defer func() { + if err := out.Close(); err != nil { + log.Errorf("Error closing log file %s: %v", s.logFile, err) + } + }() + } + + log.Printf("Running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, s.addr) + + err = cmd.Run() + + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + log.Printf("Command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode()) + } else { + log.Printf("Failed to start/run command '%s %s': %v", command, arg, err) + } return } - if len(out) != 0 { - log.Infof("command %s executed: %s", command, string(out)) + + log.Printf("Command '%s %s' completed successfully.", command, arg) +} + +func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { + if s.logFile == "" { + // attach child's streams to parent's streams + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + return nil } + + out, err := os.OpenFile(s.logFile, os.O_WRONLY|os.O_APPEND, 0) + if err != nil { + log.Errorf("Failed to open log file %s: %v", s.logFile, err) + return nil + } + cmd.Stdout = out + cmd.Stderr = out + return out } func normalizedVersion(version string) string { @@ -803,9 +855,7 @@ func normalizedVersion(version string) string { // onTrayExit is called when the tray icon is closed. func (s *serviceClient) onTrayExit() { - for _, item := range s.mExitNodeItems { - item.cancel() - } + s.cancel() } // getSrvClient connection to the service. @@ -814,7 +864,7 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService return s.conn, nil } - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(s.ctx, timeout) defer cancel() conn, err := grpc.DialContext( diff --git a/client/ui/debug.go b/client/ui/debug.go index 845ea284c..ab7dba37a 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -3,48 +3,721 @@ package main import ( + "context" "fmt" "path/filepath" + "strconv" + "sync" + "time" "fyne.io/fyne/v2" + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/dialog" + "fyne.io/fyne/v2/widget" + log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" + uptypes "github.com/netbirdio/netbird/upload-server/types" ) -func (s *serviceClient) createAndOpenDebugBundle() error { +// Initial state for the debug collection +type debugInitialState struct { + wasDown bool + logLevel proto.LogLevel + isLevelTrace bool +} + +// Debug collection parameters +type debugCollectionParams struct { + duration time.Duration + anonymize bool + systemInfo bool + upload bool + uploadURL string + enablePersistence bool +} + +// UI components for progress tracking +type progressUI struct { + statusLabel *widget.Label + progressBar *widget.ProgressBar + uiControls []fyne.Disableable + window fyne.Window +} + +func (s *serviceClient) showDebugUI() { + w := s.app.NewWindow("NetBird Debug") + w.SetOnClosed(s.cancel) + + w.Resize(fyne.NewSize(600, 500)) + w.SetFixedSize(true) + + anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil) + systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil) + systemInfoCheck.SetChecked(true) + uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil) + uploadCheck.SetChecked(true) + + uploadURLLabel := widget.NewLabel("Debug upload URL:") + uploadURL := widget.NewEntry() + uploadURL.SetText(uptypes.DefaultBundleURL) + uploadURL.SetPlaceHolder("Enter upload URL") + + uploadURLContainer := container.NewVBox( + uploadURLLabel, + uploadURL, + ) + + uploadCheck.OnChanged = func(checked bool) { + if checked { + uploadURLContainer.Show() + } else { + uploadURLContainer.Hide() + } + } + + debugModeContainer := container.NewHBox() + runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil) + runForDurationCheck.SetChecked(true) + + forLabel := widget.NewLabel("for") + + durationInput := widget.NewEntry() + durationInput.SetText("1") + minutesLabel := widget.NewLabel("minute") + durationInput.Validator = func(s string) error { + return validateMinute(s, minutesLabel) + } + + noteLabel := widget.NewLabel("Note: NetBird will be brought up and down during collection") + + runForDurationCheck.OnChanged = func(checked bool) { + if checked { + forLabel.Show() + durationInput.Show() + minutesLabel.Show() + noteLabel.Show() + } else { + forLabel.Hide() + durationInput.Hide() + minutesLabel.Hide() + noteLabel.Hide() + } + } + + debugModeContainer.Add(runForDurationCheck) + debugModeContainer.Add(forLabel) + debugModeContainer.Add(durationInput) + debugModeContainer.Add(minutesLabel) + + statusLabel := widget.NewLabel("") + statusLabel.Hide() + + progressBar := widget.NewProgressBar() + progressBar.Hide() + + createButton := widget.NewButton("Create Debug Bundle", nil) + + // UI controls that should be disabled during debug collection + uiControls := []fyne.Disableable{ + anonymizeCheck, + systemInfoCheck, + uploadCheck, + uploadURL, + runForDurationCheck, + durationInput, + createButton, + } + + createButton.OnTapped = s.getCreateHandler( + statusLabel, + progressBar, + uploadCheck, + uploadURL, + anonymizeCheck, + systemInfoCheck, + runForDurationCheck, + durationInput, + uiControls, + w, + ) + + content := container.NewVBox( + widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), + widget.NewLabel(""), + anonymizeCheck, + systemInfoCheck, + uploadCheck, + uploadURLContainer, + widget.NewLabel(""), + debugModeContainer, + noteLabel, + widget.NewLabel(""), + statusLabel, + progressBar, + createButton, + ) + + paddedContent := container.NewPadded(content) + w.SetContent(paddedContent) + + w.Show() +} + +func validateMinute(s string, minutesLabel *widget.Label) error { + if val, err := strconv.Atoi(s); err != nil || val < 1 { + return fmt.Errorf("must be a number ≥ 1") + } + if s == "1" { + minutesLabel.SetText("minute") + } else { + minutesLabel.SetText("minutes") + } + return nil +} + +// disableUIControls disables the provided UI controls +func disableUIControls(controls []fyne.Disableable) { + for _, control := range controls { + control.Disable() + } +} + +// enableUIControls enables the provided UI controls +func enableUIControls(controls []fyne.Disableable) { + for _, control := range controls { + control.Enable() + } +} + +func (s *serviceClient) getCreateHandler( + statusLabel *widget.Label, + progressBar *widget.ProgressBar, + uploadCheck *widget.Check, + uploadURL *widget.Entry, + anonymizeCheck *widget.Check, + systemInfoCheck *widget.Check, + runForDurationCheck *widget.Check, + duration *widget.Entry, + uiControls []fyne.Disableable, + w fyne.Window, +) func() { + return func() { + disableUIControls(uiControls) + statusLabel.Show() + + var url string + if uploadCheck.Checked { + url = uploadURL.Text + if url == "" { + statusLabel.SetText("Error: Upload URL is required when upload is enabled") + enableUIControls(uiControls) + return + } + } + + params := &debugCollectionParams{ + anonymize: anonymizeCheck.Checked, + systemInfo: systemInfoCheck.Checked, + upload: uploadCheck.Checked, + uploadURL: url, + enablePersistence: true, + } + + runForDuration := runForDurationCheck.Checked + if runForDuration { + minutes, err := time.ParseDuration(duration.Text + "m") + if err != nil { + statusLabel.SetText(fmt.Sprintf("Error: Invalid duration: %v", err)) + enableUIControls(uiControls) + return + } + params.duration = minutes + + statusLabel.SetText(fmt.Sprintf("Running in debug mode for %d minutes...", int(minutes.Minutes()))) + progressBar.Show() + progressBar.SetValue(0) + + go s.handleRunForDuration( + statusLabel, + progressBar, + uiControls, + w, + params, + ) + return + } + + statusLabel.SetText("Creating debug bundle...") + go s.handleDebugCreation( + anonymizeCheck.Checked, + systemInfoCheck.Checked, + uploadCheck.Checked, + url, + statusLabel, + uiControls, + w, + ) + } +} + +func (s *serviceClient) handleRunForDuration( + statusLabel *widget.Label, + progressBar *widget.ProgressBar, + uiControls []fyne.Disableable, + w fyne.Window, + params *debugCollectionParams, +) { + progressUI := &progressUI{ + statusLabel: statusLabel, + progressBar: progressBar, + uiControls: uiControls, + window: w, + } + conn, err := s.getSrvClient(failFastTimeout) if err != nil { - return fmt.Errorf("get client: %v", err) + handleError(progressUI, fmt.Sprintf("Failed to get client for debug: %v", err)) + return + } + + initialState, err := s.getInitialState(conn) + if err != nil { + handleError(progressUI, err.Error()) + return + } + + statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI) + if err != nil { + handleError(progressUI, err.Error()) + return + } + + if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil { + handleError(progressUI, err.Error()) + return + } + + s.restoreServiceState(conn, initialState) + + progressUI.statusLabel.SetText("Bundle created successfully") +} + +// Get initial state of the service +func (s *serviceClient) getInitialState(conn proto.DaemonServiceClient) (*debugInitialState, error) { + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + return nil, fmt.Errorf(" get status: %v", err) + } + + logLevelResp, err := conn.GetLogLevel(s.ctx, &proto.GetLogLevelRequest{}) + if err != nil { + return nil, fmt.Errorf("get log level: %v", err) + } + + wasDown := statusResp.Status != string(internal.StatusConnected) && + statusResp.Status != string(internal.StatusConnecting) + + initialLogLevel := logLevelResp.GetLevel() + initialLevelTrace := initialLogLevel >= proto.LogLevel_TRACE + + return &debugInitialState{ + wasDown: wasDown, + logLevel: initialLogLevel, + isLevelTrace: initialLevelTrace, + }, nil +} + +// Handle progress tracking during collection +func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time.Duration, progress *progressUI) { + progress.progressBar.Show() + progress.progressBar.SetValue(0) + + startTime := time.Now() + endTime := startTime.Add(duration) + wg.Add(1) + + go func() { + defer wg.Done() + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + remaining := time.Until(endTime) + if remaining <= 0 { + remaining = 0 + } + + elapsed := time.Since(startTime) + progressVal := float64(elapsed) / float64(duration) + if progressVal > 1.0 { + progressVal = 1.0 + } + + progress.progressBar.SetValue(progressVal) + progress.statusLabel.SetText(fmt.Sprintf("Running with trace logs... %s remaining", formatDuration(remaining))) + } + } + }() + +} + +func (s *serviceClient) configureServiceForDebug( + conn proto.DaemonServiceClient, + state *debugInitialState, + enablePersistence bool, +) error { + if state.wasDown { + if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { + return fmt.Errorf("bring service up: %v", err) + } + log.Info("Service brought up for debug") + time.Sleep(time.Second * 10) + } + + if !state.isLevelTrace { + if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil { + return fmt.Errorf("set log level to TRACE: %v", err) + } + log.Info("Log level set to TRACE for debug") + } + + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + return fmt.Errorf("bring service down: %v", err) + } + time.Sleep(time.Second) + + if enablePersistence { + if _, err := conn.SetNetworkMapPersistence(s.ctx, &proto.SetNetworkMapPersistenceRequest{ + Enabled: true, + }); err != nil { + return fmt.Errorf("enable network map persistence: %v", err) + } + log.Info("Network map persistence enabled for debug") + } + + if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { + return fmt.Errorf("bring service back up: %v", err) + } + time.Sleep(time.Second * 3) + + return nil +} + +func (s *serviceClient) collectDebugData( + conn proto.DaemonServiceClient, + state *debugInitialState, + params *debugCollectionParams, + progress *progressUI, +) (string, error) { + ctx, cancel := context.WithTimeout(s.ctx, params.duration) + defer cancel() + var wg sync.WaitGroup + startProgressTracker(ctx, &wg, params.duration, progress) + + if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil { + return "", err + } + + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) + if err != nil { + log.Warnf("Failed to get post-up status: %v", err) + } + + var postUpStatusOutput string + if postUpStatus != nil { + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil) + postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) + } + headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) + statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput) + + wg.Wait() + progress.progressBar.Hide() + progress.statusLabel.SetText("Collecting debug data...") + + preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) + if err != nil { + log.Warnf("Failed to get pre-down status: %v", err) + } + + var preDownStatusOutput string + if preDownStatus != nil { + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil) + preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) + } + headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", + time.Now().Format(time.RFC3339), params.duration) + statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput) + + return statusOutput, nil +} + +// Create the debug bundle with collected data +func (s *serviceClient) createDebugBundleFromCollection( + conn proto.DaemonServiceClient, + params *debugCollectionParams, + statusOutput string, + progress *progressUI, +) error { + progress.statusLabel.SetText("Creating debug bundle with collected logs...") + + request := &proto.DebugBundleRequest{ + Anonymize: params.anonymize, + Status: statusOutput, + SystemInfo: params.systemInfo, + } + + if params.upload { + request.UploadURL = params.uploadURL + } + + resp, err := conn.DebugBundle(s.ctx, request) + if err != nil { + return fmt.Errorf("create debug bundle: %v", err) + } + + // Show appropriate dialog based on upload status + localPath := resp.GetPath() + uploadFailureReason := resp.GetUploadFailureReason() + uploadedKey := resp.GetUploadedKey() + + if params.upload { + if uploadFailureReason != "" { + showUploadFailedDialog(progress.window, localPath, uploadFailureReason) + } else { + showUploadSuccessDialog(progress.window, localPath, uploadedKey) + } + } else { + showBundleCreatedDialog(progress.window, localPath) + } + + enableUIControls(progress.uiControls) + return nil +} + +// Restore service to original state +func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) { + if state.wasDown { + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + log.Errorf("Failed to restore down state: %v", err) + } else { + log.Info("Service state restored to down") + } + } + + if !state.isLevelTrace { + if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil { + log.Errorf("Failed to restore log level: %v", err) + } else { + log.Info("Log level restored to original setting") + } + } +} + +// Handle errors during debug collection +func handleError(progress *progressUI, errMsg string) { + log.Errorf("%s", errMsg) + progress.statusLabel.SetText(errMsg) + progress.progressBar.Hide() + enableUIControls(progress.uiControls) +} + +func (s *serviceClient) handleDebugCreation( + anonymize bool, + systemInfo bool, + upload bool, + uploadURL string, + statusLabel *widget.Label, + uiControls []fyne.Disableable, + w fyne.Window, +) { + log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...", + anonymize, systemInfo, upload) + + resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL) + if err != nil { + log.Errorf("Failed to create debug bundle: %v", err) + statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err)) + enableUIControls(uiControls) + return + } + + localPath := resp.GetPath() + uploadFailureReason := resp.GetUploadFailureReason() + uploadedKey := resp.GetUploadedKey() + + if upload { + if uploadFailureReason != "" { + showUploadFailedDialog(w, localPath, uploadFailureReason) + } else { + showUploadSuccessDialog(w, localPath, uploadedKey) + } + } else { + showBundleCreatedDialog(w, localPath) + } + + enableUIControls(uiControls) + statusLabel.SetText("Bundle created successfully") +} + +func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploadURL string) (*proto.DebugBundleResponse, error) { + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + return nil, fmt.Errorf("get client: %v", err) } statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { - return fmt.Errorf("failed to get status: %v", err) + log.Warnf("failed to get status for debug bundle: %v", err) } - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, true, "", nil, nil, nil) - statusOutput := nbstatus.ParseToFullDetailSummary(overview) + var statusOutput string + if statusResp != nil { + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil) + statusOutput = nbstatus.ParseToFullDetailSummary(overview) + } - resp, err := conn.DebugBundle(s.ctx, &proto.DebugBundleRequest{ - Anonymize: true, + request := &proto.DebugBundleRequest{ + Anonymize: anonymize, Status: statusOutput, - SystemInfo: true, - }) + SystemInfo: systemInfo, + } + + if uploadURL != "" { + request.UploadURL = uploadURL + } + + resp, err := conn.DebugBundle(s.ctx, request) if err != nil { - return fmt.Errorf("failed to create debug bundle: %v", err) + return nil, fmt.Errorf("failed to create debug bundle via daemon: %v", err) } - bundleDir := filepath.Dir(resp.GetPath()) - if err := open.Start(bundleDir); err != nil { - return fmt.Errorf("failed to open debug bundle directory: %v", err) - } - - s.app.SendNotification(fyne.NewNotification( - "Debug Bundle", - fmt.Sprintf("Debug bundle created at %s. Administrator privileges are required to access it.", resp.GetPath()), - )) - - return nil + return resp, nil +} + +// formatDuration formats a duration in HH:MM:SS format +func formatDuration(d time.Duration) string { + d = d.Round(time.Second) + h := d / time.Hour + d %= time.Hour + m := d / time.Minute + d %= time.Minute + s := d / time.Second + return fmt.Sprintf("%02d:%02d:%02d", h, m, s) +} + +// createButtonWithAction creates a button with the given label and action +func createButtonWithAction(label string, action func()) *widget.Button { + button := widget.NewButton(label, action) + return button +} + +// showUploadFailedDialog displays a dialog when upload fails +func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) { + content := container.NewVBox( + widget.NewLabel(fmt.Sprintf("Bundle upload failed:\n%s\n\n"+ + "A local copy was saved at:\n%s", failureReason, localPath)), + ) + + customDialog := dialog.NewCustom("Upload Failed", "Cancel", content, w) + + buttonBox := container.NewHBox( + createButtonWithAction("Open file", func() { + log.Infof("Attempting to open local file: %s", localPath) + if openErr := open.Start(localPath); openErr != nil { + log.Errorf("Failed to open local file '%s': %v", localPath, openErr) + dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w) + } + }), + createButtonWithAction("Open folder", func() { + folderPath := filepath.Dir(localPath) + log.Infof("Attempting to open local folder: %s", folderPath) + if openErr := open.Start(folderPath); openErr != nil { + log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr) + dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w) + } + }), + ) + + content.Add(buttonBox) + customDialog.Show() +} + +// showUploadSuccessDialog displays a dialog when upload succeeds +func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) { + log.Infof("Upload key: %s", uploadedKey) + keyEntry := widget.NewEntry() + keyEntry.SetText(uploadedKey) + keyEntry.Disable() + + content := container.NewVBox( + widget.NewLabel("Bundle uploaded successfully!"), + widget.NewLabel(""), + widget.NewLabel("Upload key:"), + keyEntry, + widget.NewLabel(""), + widget.NewLabel(fmt.Sprintf("Local copy saved at:\n%s", localPath)), + ) + + customDialog := dialog.NewCustom("Upload Successful", "OK", content, w) + + copyBtn := createButtonWithAction("Copy key", func() { + w.Clipboard().SetContent(uploadedKey) + log.Info("Upload key copied to clipboard") + }) + + buttonBox := createButtonBox(localPath, w, copyBtn) + content.Add(buttonBox) + customDialog.Show() +} + +// showBundleCreatedDialog displays a dialog when bundle is created without upload +func showBundleCreatedDialog(w fyne.Window, localPath string) { + content := container.NewVBox( + widget.NewLabel(fmt.Sprintf("Bundle created locally at:\n%s\n\n"+ + "Administrator privileges may be required to access the file.", localPath)), + ) + + customDialog := dialog.NewCustom("Debug Bundle Created", "Cancel", content, w) + + buttonBox := createButtonBox(localPath, w, nil) + content.Add(buttonBox) + customDialog.Show() +} + +func createButtonBox(localPath string, w fyne.Window, elems ...fyne.Widget) *fyne.Container { + box := container.NewHBox() + for _, elem := range elems { + box.Add(elem) + } + + fileBtn := createButtonWithAction("Open file", func() { + log.Infof("Attempting to open local file: %s", localPath) + if openErr := open.Start(localPath); openErr != nil { + log.Errorf("Failed to open local file '%s': %v", localPath, openErr) + dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w) + } + }) + + folderBtn := createButtonWithAction("Open folder", func() { + folderPath := filepath.Dir(localPath) + log.Infof("Attempting to open local folder: %s", folderPath) + if openErr := open.Start(folderPath); openErr != nil { + log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr) + dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w) + } + }) + + box.Add(fileBtn) + box.Add(folderBtn) + + return box } diff --git a/client/ui/network.go b/client/ui/network.go index b21554f09..435917f30 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -34,7 +34,8 @@ const ( type filter string func (s *serviceClient) showNetworksUI() { - s.wRoutes = s.app.NewWindow("Networks") + s.wNetworks = s.app.NewWindow("Networks") + s.wNetworks.SetOnClosed(s.cancel) allGrid := container.New(layout.NewGridLayout(3)) go s.updateNetworks(allGrid, allNetworks) @@ -78,8 +79,8 @@ func (s *serviceClient) showNetworksUI() { content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer) - s.wRoutes.SetContent(content) - s.wRoutes.Show() + s.wNetworks.SetContent(content) + s.wNetworks.Show() s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid) } @@ -148,7 +149,7 @@ func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) { grid.Add(resolvedIPsSelector) } - s.wRoutes.Content().Refresh() + s.wNetworks.Content().Refresh() grid.Refresh() } @@ -305,7 +306,7 @@ func (s *serviceClient) getNetworksRequest(f filter, appendRoute bool) *proto.Se func (s *serviceClient) showError(err error) { wrappedMessage := wrapText(err.Error(), 50) - dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes) + dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wNetworks) } func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { @@ -316,14 +317,15 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container } }() - s.wRoutes.SetOnClosed(func() { + s.wNetworks.SetOnClosed(func() { ticker.Stop() + s.cancel() }) } func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid) - s.wRoutes.Content().Refresh() + s.wNetworks.Content().Refresh() s.updateNetworks(grid, f) } @@ -373,7 +375,7 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { node.Selected, ) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(s.ctx) s.mExitNodeItems = append(s.mExitNodeItems, menuHandler{ MenuItem: menuItem, cancel: cancel, @@ -456,19 +458,27 @@ func (s *serviceClient) toggleExitNode(nodeID string, item *systray.MenuItem) er } } - if item.Checked() && len(ids) == 0 { - // exit node is the only selected node, deselect it + // exit node is the only selected node, deselect it + deselectAll := item.Checked() && len(ids) == 0 + if deselectAll { ids = append(ids, nodeID) - exitNode = nil + for _, node := range exitNodes { + if node.ID == nodeID { + // set desired state for recreation + node.Selected = false + } + } } // deselect all other selected exit nodes - if err := s.deselectOtherExitNodes(conn, ids, item); err != nil { + if err := s.deselectOtherExitNodes(conn, ids); err != nil { return err } - if err := s.selectNewExitNode(conn, exitNode, nodeID, item); err != nil { - return err + if !deselectAll { + if err := s.selectNewExitNode(conn, exitNode, nodeID, item); err != nil { + return err + } } // linux/bsd doesn't handle Check/Uncheck well, so we recreate the menu @@ -479,7 +489,7 @@ func (s *serviceClient) toggleExitNode(nodeID string, item *systray.MenuItem) er return nil } -func (s *serviceClient) deselectOtherExitNodes(conn proto.DaemonServiceClient, ids []string, currentItem *systray.MenuItem) error { +func (s *serviceClient) deselectOtherExitNodes(conn proto.DaemonServiceClient, ids []string) error { // deselect all other selected exit nodes if len(ids) > 0 { deselectReq := &proto.SelectNetworksRequest{ @@ -494,9 +504,6 @@ func (s *serviceClient) deselectOtherExitNodes(conn proto.DaemonServiceClient, i // uncheck all other exit node menu items for _, i := range s.mExitNodeItems { - if i.MenuItem == currentItem { - continue - } i.Uncheck() log.Infof("Unchecked exit node %v", i) } @@ -518,6 +525,7 @@ func (s *serviceClient) selectNewExitNode(conn proto.DaemonServiceClient, exitNo } item.Check() + log.Infof("Checked exit node '%s'", nodeID) return nil } diff --git a/dns/dns.go b/dns/dns.go index 8dfdf8526..f889a32ec 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -66,17 +66,17 @@ func (s SimpleRecord) String() string { func (s SimpleRecord) Len() uint16 { emptyString := s.RData == "" switch s.Type { - case 1: + case int(dns.TypeA): if emptyString { return 0 } return net.IPv4len - case 5: + case int(dns.TypeCNAME): if emptyString || s.RData == "." { return 1 } return uint16(len(s.RData) + 1) - case 28: + case int(dns.TypeAAAA): if emptyString { return 0 } @@ -111,6 +111,5 @@ func GetParsedDomainLabel(name string) (string, error) { // NormalizeZone returns a normalized domain name without the wildcard prefix func NormalizeZone(domain string) string { - d, _ := strings.CutPrefix(domain, "*.") - return d + return strings.TrimPrefix(domain, "*.") } diff --git a/go.mod b/go.mod index c00f32063..2b3ef9cd6 100644 --- a/go.mod +++ b/go.mod @@ -18,9 +18,9 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 - github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.36.0 - golang.org/x/sys v0.31.0 + github.com/vishvananda/netlink v1.3.0 + golang.org/x/crypto v0.37.0 + golang.org/x/sys v0.32.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -33,13 +33,15 @@ require ( fyne.io/fyne/v2 v2.5.3 fyne.io/systray v1.11.0 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible + github.com/aws/aws-sdk-go-v2 v1.36.3 + github.com/aws/aws-sdk-go-v2/config v1.29.14 + github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 github.com/c-robinson/iplib v1.0.3 github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 github.com/coder/websocket v1.8.12 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 - github.com/davecgh/go-spew v1.1.1 github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2 @@ -49,7 +51,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/gopacket v1.1.19 - github.com/google/nftables v0.2.0 + github.com/google/nftables v0.3.0 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 @@ -75,7 +77,7 @@ require ( github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.22.0 github.com/quic-go/quic-go v0.48.2 - github.com/redis/go-redis/v9 v9.7.1 + github.com/redis/go-redis/v9 v9.7.3 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -100,10 +102,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.36.0 + golang.org/x/net v0.39.0 golang.org/x/oauth2 v0.24.0 - golang.org/x/sync v0.12.0 - golang.org/x/term v0.30.0 + golang.org/x/sync v0.13.0 + golang.org/x/term v0.31.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -124,20 +126,22 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect - github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect - github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 // indirect github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect - github.com/aws/smithy-go v1.20.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -145,6 +149,7 @@ require ( github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v26.1.5+incompatible // indirect @@ -183,7 +188,6 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/josharian/native v1.1.0 // indirect github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect @@ -192,7 +196,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect - github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect github.com/mholt/acmez/v2 v2.0.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -235,7 +239,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/text v0.24.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index f00b42beb..a90db83de 100644 --- a/go.sum +++ b/go.sum @@ -74,34 +74,44 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= -github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= -github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= -github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= -github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg= -github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI= -github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10/go.mod h1:qqvMj6gHLR/EXWZw4ZbqlPbQUyenf4h82UQUlKc+l14= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 h1:ZNTqv4nIdE/DiBfUUfXcLZ/Spcuz+RjeziUtNJackkM= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34/go.mod h1:zf7Vcd1ViW7cPqYWEHLHJkS50X0JS2IKz9Cgaj6ugrs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0 h1:lguz0bmOoGzozP9XfRJR1QIayEYo+2vP/No3OfLF0pU= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0/go.mod h1:iu6FSzgt+M2/x3Dk8zhycdIcHjEFb36IS8HVUVFoMg0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 h1:moLQUoVq91LiqT1nbvzDukyqAlCv89ZmwaHw/ZFlFZg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15/go.mod h1:ZH34PJUc8ApjBIfgQCFvkWcUDBtl/WTD+uiYHjd8igA= github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 h1:MmLCRqP4U4Cw9gJ4bNrCG0mWqEtBlmAVleyelcHARMU= github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3/go.mod h1:AMPjK2YnRh0YgOID3PqhJA1BRNfXDfGOnSsKHtAe8yA= -github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM= -github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw= -github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE= -github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ= -github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= -github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 h1:tWUG+4wZqdMl/znThEk9tcCy8tTMxq8dW0JTgamohrY= +github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2/go.mod h1:U5SNqwhXB3Xe6F47kXvWihPl/ilGaEDe8HD/50Z9wxc= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= @@ -301,8 +311,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= -github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= +github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg= +github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -399,8 +409,6 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= -github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= -github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -447,8 +455,8 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= -github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= -github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= @@ -576,8 +584,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= -github.com/redis/go-redis/v9 v9.7.1 h1:4LhKRCIduqXqtvCUlaq9c8bdHOkICjDMrr1+Zb3osAc= -github.com/redis/go-redis/v9 v9.7.1/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= +github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= +github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -665,9 +673,8 @@ github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYg github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= -github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= -github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= @@ -752,8 +759,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -846,8 +853,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= -golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -876,8 +883,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -902,7 +909,6 @@ golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -911,7 +917,6 @@ golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -939,14 +944,16 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -954,8 +961,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -969,8 +976,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/management/client/grpc.go b/management/client/grpc.go index d3aaffec0..2f4729e23 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -159,6 +159,7 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, // blocking until error err = c.receiveEvents(stream, serverPubKey, msgHandler) if err != nil { + c.notifyDisconnected(err) s, _ := gstatus.FromError(err) switch s.Code() { case codes.PermissionDenied: @@ -167,7 +168,6 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, log.Debugf("management connection context has been canceled, this usually indicates shutdown") return nil default: - c.notifyDisconnected(err) log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) return err } @@ -258,10 +258,10 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se return err } - err = msgHandler(decryptedResp) - if err != nil { + if err := msgHandler(decryptedResp); err != nil { log.Errorf("failed handling an update message received from Management Service: %v", err.Error()) - return err + // hide any grpc error code that is not relevant for management + return fmt.Errorf("msg handler error: %v", err.Error()) } } } diff --git a/management/client/rest/accounts.go b/management/client/rest/accounts.go index a0ecd730c..2530e4f72 100644 --- a/management/client/rest/accounts.go +++ b/management/client/rest/accounts.go @@ -20,7 +20,9 @@ func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Account](resp) return ret, err } @@ -36,7 +38,9 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api. if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Account](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/dns.go b/management/client/rest/dns.go index 6c0dc02d3..1e35c0226 100644 --- a/management/client/rest/dns.go +++ b/management/client/rest/dns.go @@ -20,7 +20,9 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGrou if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.NameserverGroup](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID strin if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NameserverGroup](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NameserverGroup](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NameserverGroup](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -88,7 +98,9 @@ func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.DNSSettings](resp) return &ret, err } @@ -104,7 +116,9 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.DNSSettings](resp) return &ret, err } diff --git a/management/client/rest/events.go b/management/client/rest/events.go index e19532df1..cae813e86 100644 --- a/management/client/rest/events.go +++ b/management/client/rest/events.go @@ -18,7 +18,9 @@ func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Event](resp) return ret, err } diff --git a/management/client/rest/geo.go b/management/client/rest/geo.go index 6162281e2..d06d65d80 100644 --- a/management/client/rest/geo.go +++ b/management/client/rest/geo.go @@ -18,7 +18,9 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, erro if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Country](resp) return ret, err } @@ -30,7 +32,9 @@ func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode stri if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.City](resp) return ret, err } diff --git a/management/client/rest/groups.go b/management/client/rest/groups.go index 56a0e3278..7612b7188 100644 --- a/management/client/rest/groups.go +++ b/management/client/rest/groups.go @@ -20,7 +20,9 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Group](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Group](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONReq if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Group](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Group](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/networks.go b/management/client/rest/networks.go index 55ed7320a..b744e3fe7 100644 --- a/management/client/rest/networks.go +++ b/management/client/rest/networks.go @@ -20,7 +20,9 @@ func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Network](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network, if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Network](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSO if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Network](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api. if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Network](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -102,7 +112,9 @@ func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.NetworkResource](resp) return ret, err } @@ -114,7 +126,9 @@ func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string) if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkResource](resp) return &ret, err } @@ -130,7 +144,9 @@ func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNet if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkResource](resp) return &ret, err } @@ -146,7 +162,9 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkResource](resp) return &ret, err } @@ -158,7 +176,9 @@ func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID stri if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -184,7 +204,9 @@ func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, erro if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.NetworkRouter](resp) return ret, err } @@ -196,7 +218,9 @@ func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*a if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkRouter](resp) return &ret, err } @@ -212,7 +236,9 @@ func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetwo if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkRouter](resp) return &ret, err } @@ -228,7 +254,9 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string, if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkRouter](resp) return &ret, err } @@ -240,7 +268,9 @@ func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string) if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/peers.go b/management/client/rest/peers.go index 23220609f..37679fdb9 100644 --- a/management/client/rest/peers.go +++ b/management/client/rest/peers.go @@ -20,7 +20,9 @@ func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Peer](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Peer](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Peer](resp) return &ret, err } @@ -60,7 +66,9 @@ func (a *PeersAPI) Delete(ctx context.Context, peerID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -72,7 +80,9 @@ func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]ap if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Peer](resp) return ret, err } diff --git a/management/client/rest/policies.go b/management/client/rest/policies.go index 426a688e2..006b0eeb7 100644 --- a/management/client/rest/policies.go +++ b/management/client/rest/policies.go @@ -24,7 +24,9 @@ func (a *PoliciesAPI) List(ctx context.Context, accountID string) ([]api.Policy, if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Policy](resp) return ret, err } @@ -36,7 +38,9 @@ func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, er if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Policy](resp) return &ret, err } @@ -52,7 +56,9 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Policy](resp) return &ret, err } @@ -72,7 +78,9 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Policy](resp) return &ret, err } @@ -84,7 +92,9 @@ func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/posturechecks.go b/management/client/rest/posturechecks.go index bcb62f096..622eeeb64 100644 --- a/management/client/rest/posturechecks.go +++ b/management/client/rest/posturechecks.go @@ -20,7 +20,9 @@ func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error) if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.PostureCheck](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PostureCheck](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostur if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PostureCheck](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PostureCheck](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) er if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/routes.go b/management/client/rest/routes.go index b85084c8d..671c3bfc9 100644 --- a/management/client/rest/routes.go +++ b/management/client/rest/routes.go @@ -20,7 +20,9 @@ func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Route](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error) if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Route](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONReq if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Route](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Route](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/setupkeys.go b/management/client/rest/setupkeys.go index 7e98107d1..ceb3996d6 100644 --- a/management/client/rest/setupkeys.go +++ b/management/client/rest/setupkeys.go @@ -20,7 +20,9 @@ func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.SetupKey](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKe if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.SetupKey](resp) return &ret, err } @@ -53,7 +57,9 @@ func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJ if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.SetupKeyClear](resp) return &ret, err } @@ -69,7 +75,9 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.SetupKey](resp) return &ret, err } @@ -81,7 +89,9 @@ func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/tokens.go b/management/client/rest/tokens.go index 9fcab82ef..278a0d159 100644 --- a/management/client/rest/tokens.go +++ b/management/client/rest/tokens.go @@ -20,7 +20,9 @@ func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAcce if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.PersonalAccessToken](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.Perso if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PersonalAccessToken](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PersonalAccessTokenGenerated](resp) return &ret, err } @@ -60,7 +66,9 @@ func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/users.go b/management/client/rest/users.go index cb13ca617..107b0581e 100644 --- a/management/client/rest/users.go +++ b/management/client/rest/users.go @@ -20,7 +20,9 @@ func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.User](resp) return ret, err } @@ -36,7 +38,9 @@ func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONReque if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.User](resp) return &ret, err } @@ -52,7 +56,9 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.User](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *UsersAPI) Delete(ctx context.Context, userID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -76,7 +84,9 @@ func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -88,7 +98,9 @@ func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.User](resp) return &ret, err diff --git a/management/client/rest/users_test.go b/management/client/rest/users_test.go index f68c5f083..715eb1661 100644 --- a/management/client/rest/users_test.go +++ b/management/client/rest/users_test.go @@ -30,11 +30,8 @@ var ( Issued: ptr("api"), LastLogin: &time.Time{}, Name: "M. Essam", - Permissions: &api.UserPermissions{ - DashboardView: ptr(api.UserPermissionsDashboardViewFull), - }, - Role: "user", - Status: api.UserStatusActive, + Role: "user", + Status: api.UserStatusActive, } ) diff --git a/management/domain/domain.go b/management/domain/domain.go index 2e089b01f..97acec688 100644 --- a/management/domain/domain.go +++ b/management/domain/domain.go @@ -1,12 +1,17 @@ package domain import ( + "strings" + "golang.org/x/net/idna" ) +// Domain represents a punycode-encoded domain string. +// This should only be converted from a string when the string already is in punycode, otherwise use FromString. type Domain string // String converts the Domain to a non-punycode string. +// For an infallible conversion, use SafeString. func (d Domain) String() (string, error) { unicode, err := idna.ToUnicode(string(d)) if err != nil { @@ -15,16 +20,17 @@ func (d Domain) String() (string, error) { return unicode, nil } -// SafeString converts the Domain to a non-punycode string, falling back to the original string if conversion fails. +// SafeString converts the Domain to a non-punycode string, falling back to the punycode string if conversion fails. func (d Domain) SafeString() string { str, err := d.String() if err != nil { - str = string(d) + return string(d) } return str } // PunycodeString returns the punycode representation of the Domain. +// This should only be used if a punycode domain is expected but only a string is supported. func (d Domain) PunycodeString() string { return string(d) } @@ -35,5 +41,5 @@ func FromString(s string) (Domain, error) { if err != nil { return "", err } - return Domain(ascii), nil + return Domain(strings.ToLower(ascii)), nil } diff --git a/management/domain/list.go b/management/domain/list.go index b6090c717..a988f4f70 100644 --- a/management/domain/list.go +++ b/management/domain/list.go @@ -5,6 +5,7 @@ import ( "strings" ) +// List is a slice of punycode-encoded domain strings. type List []Domain // ToStringList converts a List to a slice of string. @@ -53,7 +54,7 @@ func (d List) String() (string, error) { func (d List) SafeString() string { str, err := d.String() if err != nil { - return strings.Join(d.ToPunycodeList(), ", ") + return d.PunycodeString() } return str } @@ -101,7 +102,7 @@ func FromStringList(s []string) (List, error) { func FromPunycodeList(s []string) List { var dl List for _, domain := range s { - dl = append(dl, Domain(domain)) + dl = append(dl, Domain(strings.ToLower(domain))) } return dl } diff --git a/management/domain/validate.go b/management/domain/validate.go index bcbf26e05..a42aebe6f 100644 --- a/management/domain/validate.go +++ b/management/domain/validate.go @@ -22,8 +22,6 @@ func ValidateDomains(domains []string) (List, error) { var domainList List for _, d := range domains { - d := strings.ToLower(d) - // handles length and idna conversion punycode, err := FromString(d) if err != nil { diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index f3f53bfd4..9d7fdc682 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -3057,6 +3057,8 @@ type RouteFirewallRule struct { CustomProtocol uint32 `protobuf:"varint,8,opt,name=customProtocol,proto3" json:"customProtocol,omitempty"` // PolicyID is the ID of the policy that this rule belongs to PolicyID []byte `protobuf:"bytes,9,opt,name=PolicyID,proto3" json:"PolicyID,omitempty"` + // RouteID is the ID of the route that this rule belongs to + RouteID string `protobuf:"bytes,10,opt,name=RouteID,proto3" json:"RouteID,omitempty"` } func (x *RouteFirewallRule) Reset() { @@ -3154,6 +3156,13 @@ func (x *RouteFirewallRule) GetPolicyID() []byte { return nil } +func (x *RouteFirewallRule) GetRouteID() string { + if x != nil { + return x.RouteID + } + return "" +} + type ForwardingRule struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3702,7 +3711,7 @@ var file_management_proto_rawDesc = []byte{ 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, - 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0xed, 0x02, 0x0a, 0x11, 0x52, 0x6f, + 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, @@ -3725,66 +3734,68 @@ var file_management_proto_rawDesc = []byte{ 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, - 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, - 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, - 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, - 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, - 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, - 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, - 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, - 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, - 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, - 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, - 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, - 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, - 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, - 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, + 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, + 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, + 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, + 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, + 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, + 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, + 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, + 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, + 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, + 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, + 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, + 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, + 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, + 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, + 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, - 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, - 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, - 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, - 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, - 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, + 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, + 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, + 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, + 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, - 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, + 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, + 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, + 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, + 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( diff --git a/management/proto/management.proto b/management/proto/management.proto index 0f1cdb97a..f0dc16ce2 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -509,6 +509,9 @@ message RouteFirewallRule { // PolicyID is the ID of the policy that this rule belongs to bytes PolicyID = 9; + + // RouteID is the ID of the route that this rule belongs to + string RouteID = 10; } message ForwardingRule { diff --git a/management/server/account.go b/management/server/account.go index d7f108dfe..5c474a343 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -17,6 +17,7 @@ import ( "time" cacheStore "github.com/eko/gocache/lib/v4/store" + "github.com/eko/gocache/store/redis/v4" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/vmihailenco/msgpack/v5" @@ -237,7 +238,7 @@ func BuildManager( if !isNil(am.idpManager) { go func() { - err := am.warmupIDPCache(ctx) + err := am.warmupIDPCache(ctx, cacheStore) if err != nil { log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? @@ -275,6 +276,10 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } + if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { + return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) + } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -325,6 +330,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco account.Network.Serial++ } + if oldSettings.DNSDomain != newSettings.DNSDomain { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil) + updateAccountPeers = true + account.Network.Serial++ + } + err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err @@ -484,7 +495,25 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain return nil, status.Errorf(status.Internal, "error while creating new account") } -func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { +func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context, store cacheStore.StoreInterface) error { + cold, err := am.isCacheCold(ctx, store) + if err != nil { + return err + } + + if !cold { + log.WithContext(ctx).Debug("cache already populated, skipping warm up") + return nil + } + + if delayStr, ok := os.LookupEnv("NB_IDP_CACHE_WARMUP_DELAY"); ok { + delay, err := time.ParseDuration(delayStr) + if err != nil { + return fmt.Errorf("invalid IDP warmup delay: %w", err) + } + time.Sleep(delay) + } + userData, err := am.idpManager.GetAllAccounts(ctx) if err != nil { return err @@ -524,6 +553,32 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { return nil } +// isCacheCold checks if the cache needs warming up. +func (am *DefaultAccountManager) isCacheCold(ctx context.Context, store cacheStore.StoreInterface) (bool, error) { + if store.GetType() != redis.RedisType { + return true, nil + } + + accountID, err := am.Store.GetAnyAccountID(ctx) + if err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return true, nil + } + return false, err + } + + _, err = store.Get(ctx, accountID) + if err == nil { + return false, nil + } + + if notFoundErr := new(cacheStore.NotFound); errors.As(err, ¬FoundErr) { + return true, nil + } + + return false, fmt.Errorf("failed to check cache: %w", err) +} + // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -548,11 +603,15 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u } for _, otherUser := range account.Users { - if otherUser.IsServiceUser { + if otherUser.Id == userID { continue } - if otherUser.Id == userID { + if otherUser.IsServiceUser { + err = am.deleteServiceUser(ctx, accountID, userID, otherUser) + if err != nil { + return err + } continue } @@ -657,7 +716,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) accountIDString := fmt.Sprintf("%v", accountID) - account, err := am.Store.GetAccount(ctx, accountIDString) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString) if err != nil { return nil, nil, err } @@ -666,7 +725,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) if err != nil { return nil, nil, err } - log.WithContext(ctx).Debugf("%d entries received from IdP management", len(userData)) + log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), accountIDString) dataMap := make(map[string]*idp.UserData, len(userData)) for _, datum := range userData { @@ -674,7 +733,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) } matchedUserData := make([]*idp.UserData, 0) - for _, user := range account.Users { + for _, user := range accountUsers { if user.IsServiceUser { continue } @@ -1057,6 +1116,19 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s return am.Store.GetAccount(ctx, accountID) } +// GetAccountMeta returns the account metadata associated with this account ID. +func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) +} + func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) @@ -1480,8 +1552,15 @@ func isDomainValid(domain string) bool { } // GetDNSDomain returns the configured dnsDomain -func (am *DefaultAccountManager) GetDNSDomain() string { - return am.dnsDomain +func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string { + if settings == nil { + return am.dnsDomain + } + if settings.DNSDomain == "" { + return am.dnsDomain + } + + return settings.DNSDomain } func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { diff --git a/management/server/account/manager.go b/management/server/account/manager.go index ea664d10e..9bc4f9605 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" ) @@ -37,6 +38,7 @@ type Manager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) + GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) @@ -80,7 +82,7 @@ type Manager interface { SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - GetDNSDomain() string + GetDNSDomain(settings *types.Settings) string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) @@ -114,5 +116,5 @@ type Manager interface { CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) - GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) + GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f34cf845..c5583d226 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -14,30 +14,30 @@ import ( "time" "github.com/golang/mock/gomock" - - nbAccount "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/util" - - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - + "github.com/netbirdio/netbird/management/server/idp" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" + nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -853,6 +853,42 @@ func TestAccountManager_DeleteAccount(t *testing.T) { t.Fatal(err) } + account.Users["service-user-1"] = &types.User{ + Id: "service-user-1", + Role: types.UserRoleAdmin, + IsServiceUser: true, + Issued: types.UserIssuedAPI, + PATs: map[string]*types.PersonalAccessToken{ + "pat-1": { + ID: "pat-1", + UserID: "service-user-1", + Name: "service-user-1", + HashedToken: "hashedToken", + CreatedAt: time.Now(), + }, + }, + } + account.Users[userId] = &types.User{ + Id: "service-user-2", + Role: types.UserRoleUser, + IsServiceUser: true, + Issued: types.UserIssuedAPI, + PATs: map[string]*types.PersonalAccessToken{ + "pat-2": { + ID: "pat-2", + UserID: userId, + Name: userId, + HashedToken: "hashedToken", + CreatedAt: time.Now(), + }, + }, + } + + err = manager.Store.SaveAccount(context.Background(), account) + if err != nil { + t.Fatal(err) + } + err = manager.DeleteAccount(context.Background(), account.Id, userId) if err != nil { t.Fatal(err) @@ -862,6 +898,14 @@ func TestAccountManager_DeleteAccount(t *testing.T) { if err == nil { t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) } + + pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1") + require.NoError(t, err) + assert.Len(t, pats, 0) + + pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId) + require.NoError(t, err) + assert.Len(t, pats, 0) } func BenchmarkTest_GetAccountWithclaims(b *testing.B) { @@ -3201,3 +3245,53 @@ func Test_UpdateToPrimaryAccount(t *testing.T) { assert.NoError(t, err) assert.True(t, account.IsDomainPrimaryAccount) } + +func TestDefaultAccountManager_IsCacheCold(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + t.Run("memory cache", func(t *testing.T) { + t.Run("should always return true", func(t *testing.T) { + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + require.NoError(t, err) + + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.True(t, cold) + }) + }) + + t.Run("redis cache", func(t *testing.T) { + cleanup, redisURL, err := testutil.CreateRedisTestContainer() + require.NoError(t, err) + t.Cleanup(cleanup) + t.Setenv(cache.RedisStoreEnvVar, redisURL) + + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + require.NoError(t, err) + + t.Run("should return true when no account exists", func(t *testing.T) { + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.True(t, cold) + }) + + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + require.NoError(t, err) + + t.Run("should return true when account is not found in cache", func(t *testing.T) { + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.True(t, cold) + }) + + t.Run("should return false when account is found in cache", func(t *testing.T) { + err = cacheStore.Set(context.Background(), account.Id, &idp.UserData{ID: "v", Name: "vv"}) + require.NoError(t, err) + + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.False(t, cold) + }) + }) +} diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 46ae754cf..ed4be82e2 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -169,6 +169,8 @@ const ( ResourceAddedToGroup Activity = 82 ResourceRemovedFromGroup Activity = 83 + + AccountDNSDomainUpdated Activity = 84 ) var activityMap = map[Activity]Code{ @@ -264,6 +266,8 @@ var activityMap = map[Activity]Code{ ResourceAddedToGroup: {"Resource added to group", "resource.group.add"}, ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"}, + + AccountDNSDomainUpdated: {"Account DNS domain updated", "account.dns.domain.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/cache/idp_test.go b/management/server/cache/idp_test.go index beefcd9bd..3fcfbb11a 100644 --- a/management/server/cache/idp_test.go +++ b/management/server/cache/idp_test.go @@ -8,12 +8,11 @@ import ( "github.com/eko/gocache/lib/v4/store" "github.com/redis/go-redis/v9" - "github.com/testcontainers/testcontainers-go" - testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/vmihailenco/msgpack/v5" "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/testutil" ) func TestNewIDPCacheManagers(t *testing.T) { @@ -27,21 +26,11 @@ func TestNewIDPCacheManagers(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { if tc.redis { - ctx := context.Background() - redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7")) + cleanup, redisURL, err := testutil.CreateRedisTestContainer() if err != nil { t.Fatalf("couldn't start redis container: %s", err) } - defer func() { - if err := redisContainer.Terminate(ctx); err != nil { - t.Logf("failed to terminate container: %s", err) - } - }() - redisURL, err := redisContainer.ConnectionString(ctx) - if err != nil { - t.Fatalf("couldn't get connection string: %s", err) - } - + t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) } cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval) diff --git a/management/server/group.go b/management/server/group.go index 0bd840798..87d649228 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -158,6 +158,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac return nil } + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) + return nil + } + dnsDomain := am.GetDNSDomain(settings) + for _, peerID := range addedPeers { peer, ok := peers[peerID] if !ok { @@ -168,7 +175,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac eventsToStore = append(eventsToStore, func() { meta := map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(dnsDomain), } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) @@ -184,7 +191,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac eventsToStore = append(eventsToStore, func() { meta := map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(dnsDomain), } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index a7ed639c3..43d35f643 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -480,20 +480,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p s.ephemeralManager.OnPeerDisconnected(ctx, peer) } - var relayToken *Token - if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { - relayToken, err = s.secretsManager.GenerateRelayToken() - if err != nil { - log.Errorf("failed generating Relay token: %v", err) - } + loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) + if err != nil { + log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) + return nil, status.Errorf(codes.Internal, "failed logging in peer") } - // if peer has reached this point then it has logged in - loginResp := &proto.LoginResponse{ - NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), false), - Checks: toProtocolChecks(ctx, postureChecks), - } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) @@ -506,6 +498,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p }, nil } +func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { + var relayToken *Token + var err error + if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { + relayToken, err = s.secretsManager.GenerateRelayToken() + if err != nil { + log.Errorf("failed generating Relay token: %v", err) + } + } + + settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, activity.SystemInitiator) + if err != nil { + log.WithContext(ctx).Warnf("failed getting settings for peer %s: %s", peer.Key, err) + return nil, status.Errorf(codes.Internal, "failed getting settings") + } + + // if peer has reached this point then it has logged in + loginResp := &proto.LoginResponse{ + NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), + PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), false), + Checks: toProtocolChecks(ctx, postureChecks), + } + + return loginResp, nil +} + // processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if // the token is valid. // @@ -712,7 +730,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra) + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index c699e9eef..bf40777fc 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -43,9 +43,30 @@ components: example: ch8i4ug6lnn4g9hqv7l0 settings: $ref: '#/components/schemas/AccountSettings' + domain: + description: Account domain + type: string + example: netbird.io + domain_category: + description: Account domain category + type: string + example: private + created_at: + description: Account creation date (UTC) + type: string + format: date-time + example: "2023-05-05T09:00:35.477782Z" + created_by: + description: Account creator + type: string + example: google-oauth2|277474792786460067937 required: - id - settings + - domain + - domain_category + - created_at + - created_by AccountSettings: type: object properties: @@ -91,6 +112,10 @@ components: description: Enables or disables DNS resolution on the routing peers type: boolean example: true + dns_domain: + description: Allows to define a custom dns domain for the account + type: string + example: my-organization.org extra: $ref: '#/components/schemas/AccountExtraSettings' required: @@ -191,11 +216,25 @@ components: UserPermissions: type: object properties: - dashboard_view: - description: User's permission to view the dashboard - type: string - enum: [ "limited", "blocked", "full" ] - example: limited + is_restricted: + type: boolean + description: Indicates whether this User's Peers view is restricted + modules: + type: object + additionalProperties: + type: object + additionalProperties: + type: boolean + propertyNames: + type: string + description: The operation type + propertyNames: + type: string + description: The module name + example: {"networks": { "read": true, "create": false, "update": false, "delete": false}, "peers": { "read": false, "create": false, "update": false, "delete": false} } + required: + - modules + - is_restricted UserRequest: type: object properties: @@ -1990,6 +2029,32 @@ components: - policy_name - icmp_type - icmp_code + NetworkTrafficEventsResponse: + type: object + properties: + data: + type: array + description: List of network traffic events + items: + $ref: "#/components/schemas/NetworkTrafficEvent" + page: + type: integer + description: Current page number + page_size: + type: integer + description: Number of items per page + total_records: + type: integer + description: Total number of event records available + total_pages: + type: integer + description: Total number of pages available + required: + - data + - page + - page_size + - total_records + - total_pages responses: not_found: description: Resource not found @@ -4206,15 +4271,77 @@ paths: tags: [ Events ] x-cloud-only: true x-experimental: true + parameters: + - name: page + in: query + description: Page number + required: false + schema: + type: integer + minimum: 1 + default: 1 + - name: page_size + in: query + description: Number of items per page + required: false + schema: + type: integer + minimum: 1 + maximum: 50000 + default: 1000 + - name: user_id + in: query + description: Filter by user ID + required: false + schema: + type: string + - name: protocol + in: query + description: Filter by protocol + required: false + schema: + type: integer + - name: type + in: query + description: Filter by event type + required: false + schema: + type: string + enum: [TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP] + - name: direction + in: query + description: Filter by direction + required: false + schema: + type: string + enum: [INGRESS, EGRESS, DIRECTION_UNKNOWN] + - name: search + in: query + description: Filters events with a partial match on user email, source and destination names and source and destination addresses + required: false + schema: + type: string + - name: start_date + in: query + description: Start date for filtering events (ISO 8601 format, e.g., 2024-01-01T00:00:00Z). + required: false + schema: + type: string + format: date-time + - name: end_date + in: query + description: End date for filtering events (ISO 8601 format, e.g., 2024-01-31T23:59:59Z). + required: false + schema: + type: string + format: date-time responses: "200": description: List of network traffic events content: application/json: schema: - type: array - items: - $ref: "#/components/schemas/NetworkTrafficEvent" + $ref: "#/components/schemas/NetworkTrafficEventsResponse" '400': "$ref": "#/components/responses/bad_request" '401': diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 9bdb3e4ac..e108c6884 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -178,11 +178,19 @@ const ( UserStatusInvited UserStatus = "invited" ) -// Defines values for UserPermissionsDashboardView. +// Defines values for GetApiEventsNetworkTrafficParamsType. const ( - UserPermissionsDashboardViewBlocked UserPermissionsDashboardView = "blocked" - UserPermissionsDashboardViewFull UserPermissionsDashboardView = "full" - UserPermissionsDashboardViewLimited UserPermissionsDashboardView = "limited" + GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP" + GetApiEventsNetworkTrafficParamsTypeTYPEEND GetApiEventsNetworkTrafficParamsType = "TYPE_END" + GetApiEventsNetworkTrafficParamsTypeTYPESTART GetApiEventsNetworkTrafficParamsType = "TYPE_START" + GetApiEventsNetworkTrafficParamsTypeTYPEUNKNOWN GetApiEventsNetworkTrafficParamsType = "TYPE_UNKNOWN" +) + +// Defines values for GetApiEventsNetworkTrafficParamsDirection. +const ( + GetApiEventsNetworkTrafficParamsDirectionDIRECTIONUNKNOWN GetApiEventsNetworkTrafficParamsDirection = "DIRECTION_UNKNOWN" + GetApiEventsNetworkTrafficParamsDirectionEGRESS GetApiEventsNetworkTrafficParamsDirection = "EGRESS" + GetApiEventsNetworkTrafficParamsDirectionINGRESS GetApiEventsNetworkTrafficParamsDirection = "INGRESS" ) // AccessiblePeer defines model for AccessiblePeer. @@ -223,6 +231,18 @@ type AccessiblePeer struct { // Account defines model for Account. type Account struct { + // CreatedAt Account creation date (UTC) + CreatedAt time.Time `json:"created_at"` + + // CreatedBy Account creator + CreatedBy string `json:"created_by"` + + // Domain Account domain + Domain string `json:"domain"` + + // DomainCategory Account domain category + DomainCategory string `json:"domain_category"` + // Id Account ID Id string `json:"id"` Settings AccountSettings `json:"settings"` @@ -247,7 +267,9 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { - Extra *AccountExtraSettings `json:"extra,omitempty"` + // DnsDomain Allows to define a custom dns domain for the account + DnsDomain *string `json:"dns_domain,omitempty"` + Extra *AccountExtraSettings `json:"extra,omitempty"` // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"` @@ -908,6 +930,24 @@ type NetworkTrafficEvent struct { UserName *string `json:"user_name"` } +// NetworkTrafficEventsResponse defines model for NetworkTrafficEventsResponse. +type NetworkTrafficEventsResponse struct { + // Data List of network traffic events + Data []NetworkTrafficEvent `json:"data"` + + // Page Current page number + Page int `json:"page"` + + // PageSize Number of items per page + PageSize int `json:"page_size"` + + // TotalPages Total number of pages available + TotalPages int `json:"total_pages"` + + // TotalRecords Total number of event records available + TotalRecords int `json:"total_records"` +} + // NetworkTrafficLocation defines model for NetworkTrafficLocation. type NetworkTrafficLocation struct { // CityName Name of the city (if known). @@ -1710,13 +1750,11 @@ type UserCreateRequest struct { // UserPermissions defines model for UserPermissions. type UserPermissions struct { - // DashboardView User's permission to view the dashboard - DashboardView *UserPermissionsDashboardView `json:"dashboard_view,omitempty"` + // IsRestricted Indicates whether this User's Peers view is restricted + IsRestricted bool `json:"is_restricted"` + Modules map[string]map[string]bool `json:"modules"` } -// UserPermissionsDashboardView User's permission to view the dashboard -type UserPermissionsDashboardView string - // UserRequest defines model for UserRequest. type UserRequest struct { // AutoGroups Group IDs to auto-assign to peers registered by this user @@ -1729,6 +1767,42 @@ type UserRequest struct { Role string `json:"role"` } +// GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. +type GetApiEventsNetworkTrafficParams struct { + // Page Page number + Page *int `form:"page,omitempty" json:"page,omitempty"` + + // PageSize Number of items per page + PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"` + + // UserId Filter by user ID + UserId *string `form:"user_id,omitempty" json:"user_id,omitempty"` + + // Protocol Filter by protocol + Protocol *int `form:"protocol,omitempty" json:"protocol,omitempty"` + + // Type Filter by event type + Type *GetApiEventsNetworkTrafficParamsType `form:"type,omitempty" json:"type,omitempty"` + + // Direction Filter by direction + Direction *GetApiEventsNetworkTrafficParamsDirection `form:"direction,omitempty" json:"direction,omitempty"` + + // Search Filters events with a partial match on user email, source and destination names and source and destination addresses + Search *string `form:"search,omitempty" json:"search,omitempty"` + + // StartDate Start date for filtering events (ISO 8601 format, e.g., 2024-01-01T00:00:00Z). + StartDate *time.Time `form:"start_date,omitempty" json:"start_date,omitempty"` + + // EndDate End date for filtering events (ISO 8601 format, e.g., 2024-01-31T23:59:59Z). + EndDate *time.Time `form:"end_date,omitempty" json:"end_date,omitempty"` +} + +// GetApiEventsNetworkTrafficParamsType defines parameters for GetApiEventsNetworkTraffic. +type GetApiEventsNetworkTrafficParamsType string + +// GetApiEventsNetworkTrafficParamsDirection defines parameters for GetApiEventsNetworkTraffic. +type GetApiEventsNetworkTrafficParamsDirection string + // GetApiPeersParams defines parameters for GetApiPeers. type GetApiPeersParams struct { // Name Filter peers by name diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 483bb989a..3d4de31d0 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -62,6 +62,7 @@ func NewAPIHandler( authManager, accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, + accountManager.GetUserFromUserAuth, ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 6c8f8028a..7cad26bd6 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -47,13 +47,19 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { accountID, userID := userAuth.AccountId, userAuth.UserId + meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(accountID, settings) + resp := toAccountResponse(accountID, settings, meta) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -113,6 +119,9 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { if req.Settings.RoutingPeerDnsResolutionEnabled != nil { settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled } + if req.Settings.DnsDomain != nil { + settings.DNSDomain = *req.Settings.DnsDomain + } updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { @@ -120,7 +129,13 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) + meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings, meta) util.WriteJSONObject(r.Context(), w, &resp) } @@ -149,7 +164,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *types.Settings) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -166,6 +181,7 @@ func toAccountResponse(accountID string, settings *types.Settings) *api.Account JwtAllowGroups: &jwtAllowGroups, RegularUsersViewBlocked: settings.RegularUsersViewBlocked, RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, + DnsDomain: &settings.DNSDomain, } if settings.Extra != nil { @@ -177,7 +193,11 @@ func toAccountResponse(accountID string, settings *types.Settings) *api.Account } return &api.Account{ - Id: accountID, - Settings: apiSettings, + Id: accountID, + Settings: apiSettings, + CreatedAt: meta.CreatedAt, + CreatedBy: meta.CreatedBy, + Domain: meta.Domain, + DomainCategory: meta.DomainCategory, } } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index e971a6514..57bbffc7c 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -50,6 +50,12 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { accCopy.UpdateSettings(newSettings) return accCopy, nil }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { + return account.Copy(), nil + }, + GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { + return account.GetMeta(), nil + }, }, settingsManager: settingsMockManager, } @@ -102,6 +108,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{}, RegularUsersViewBlocked: true, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: true, expectedID: accountID, @@ -122,6 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{}, RegularUsersViewBlocked: false, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: false, expectedID: accountID, @@ -142,6 +150,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{"test"}, RegularUsersViewBlocked: true, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: false, expectedID: accountID, @@ -162,6 +171,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{}, RegularUsersViewBlocked: true, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index fa78836d8..58ea06ea3 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -65,7 +65,13 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain() + settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + if err != nil { + util.WriteError(ctx, err, w) + return + } + + dnsDomain := h.accountManager.GetDNSDomain(settings) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) @@ -110,7 +116,13 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain() + + settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + if err != nil { + util.WriteError(ctx, err, w) + return + } + dnsDomain := h.accountManager.GetDNSDomain(settings) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) if err != nil { @@ -192,7 +204,12 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - dnsDomain := h.accountManager.GetDNSDomain() + settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + dnsDomain := h.accountManager.GetDNSDomain(settings) grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) @@ -279,7 +296,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - dnsDomain := h.accountManager.GetDNSDomain() + dnsDomain := h.accountManager.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index a03c3c29d..a1fc13dd3 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -152,7 +152,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { }, }, nil }, - GetDNSDomainFunc: func() string { + GetDNSDomainFunc: func(settings *types.Settings) string { return "netbird.selfhosted" }, GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { @@ -172,6 +172,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { _, ok := statuses[peerID] return ok }, + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return account.Settings, nil + }, }, } } diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index c69c6b944..ac04b8e35 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" nbcontext "github.com/netbirdio/netbird/management/server/context" ) @@ -272,15 +273,33 @@ func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) { return } - accountID, userID := userAuth.AccountId, userAuth.UserId - - user, err := h.accountManager.GetCurrentUserInfo(ctx, accountID, userID) + user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth) if err != nil { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(user, userID)) + util.WriteJSONObject(r.Context(), w, toUserWithPermissionsResponse(user, userAuth.UserId)) +} + +func toUserWithPermissionsResponse(user *users.UserInfoWithPermissions, userID string) *api.User { + response := toUserResponse(user.UserInfo, userID) + + // stringify modules and operations keys + modules := make(map[string]map[string]bool) + for module, operations := range user.Permissions { + modules[string(module)] = make(map[string]bool) + for op, val := range operations { + modules[string(module)][string(op)] = val + } + } + + response.Permissions = &api.UserPermissions{ + IsRestricted: user.Restricted, + Modules: modules, + } + + return response } func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { @@ -316,8 +335,5 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { IsBlocked: user.IsBlocked, LastLogin: &user.LastLogin, Issued: &user.Issued, - Permissions: &api.UserPermissions{ - DashboardView: (*api.UserPermissionsDashboardView)(&user.Permissions.DashboardView), - }, } } diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index 604954819..58e33a6d5 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -13,12 +13,16 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" ) const ( @@ -107,7 +111,7 @@ func initUsersTestData() *handler { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } - info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false}) + info, err := update.Copy().ToUserInfo(nil) if err != nil { return nil, err } @@ -124,8 +128,8 @@ func initUsersTestData() *handler { return nil }, - GetCurrentUserInfoFunc: func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) { - switch userID { + GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { + switch userAuth.UserId { case "not-found": return nil, status.NewUserNotFoundError("not-found") case "not-of-account": @@ -135,52 +139,68 @@ func initUsersTestData() *handler { case "service-user": return nil, status.NewPermissionDeniedError() case "owner": - return &types.UserInfo{ - ID: "owner", - Name: "", - Role: "owner", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - Issued: "api", - Permissions: types.UserPermissions{ - DashboardView: "full", + return &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "owner", + Name: "", + Role: "owner", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + Issued: "api", }, + Permissions: mergeRolePermissions(roles.Owner), }, nil case "regular-user": - return &types.UserInfo{ - ID: "regular-user", - Name: "", - Role: "user", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - Issued: "api", - Permissions: types.UserPermissions{ - DashboardView: "limited", + return &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "regular-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + Issued: "api", }, + Permissions: mergeRolePermissions(roles.User), }, nil case "admin-user": - return &types.UserInfo{ - ID: "admin-user", - Name: "", - Role: "admin", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - Permissions: types.UserPermissions{ - DashboardView: "full", + return &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "admin-user", + Name: "", + Role: "admin", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", }, + Permissions: mergeRolePermissions(roles.Admin), + }, nil + case "restricted-user": + return &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "restricted-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + }, + Permissions: mergeRolePermissions(roles.User), + Restricted: true, }, nil } - return nil, fmt.Errorf("user id %s not handled", userID) + return nil, fmt.Errorf("user id %s not handled", userAuth.UserId) }, }, } @@ -546,6 +566,7 @@ func TestCurrentUser(t *testing.T) { name string expectedStatus int requestAuth nbcontext.UserAuth + expectedResult *api.User }{ { name: "without auth", @@ -575,16 +596,78 @@ func TestCurrentUser(t *testing.T) { name: "owner", requestAuth: nbcontext.UserAuth{UserId: "owner"}, expectedStatus: http.StatusOK, + expectedResult: &api.User{ + Id: "owner", + Role: "owner", + Status: "active", + IsBlocked: false, + IsCurrent: ptr(true), + IsServiceUser: ptr(false), + AutoGroups: []string{}, + Issued: ptr("api"), + LastLogin: ptr(time.Time{}), + Permissions: &api.UserPermissions{ + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)), + }, + }, }, { name: "regular user", requestAuth: nbcontext.UserAuth{UserId: "regular-user"}, expectedStatus: http.StatusOK, + expectedResult: &api.User{ + Id: "regular-user", + Role: "user", + Status: "active", + IsBlocked: false, + IsCurrent: ptr(true), + IsServiceUser: ptr(false), + AutoGroups: []string{}, + Issued: ptr("api"), + LastLogin: ptr(time.Time{}), + Permissions: &api.UserPermissions{ + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)), + }, + }, }, { name: "admin user", requestAuth: nbcontext.UserAuth{UserId: "admin-user"}, expectedStatus: http.StatusOK, + expectedResult: &api.User{ + Id: "admin-user", + Role: "admin", + Status: "active", + IsBlocked: false, + IsCurrent: ptr(true), + IsServiceUser: ptr(false), + AutoGroups: []string{}, + Issued: ptr("api"), + LastLogin: ptr(time.Time{}), + Permissions: &api.UserPermissions{ + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)), + }, + }, + }, + { + name: "restricted user", + requestAuth: nbcontext.UserAuth{UserId: "restricted-user"}, + expectedStatus: http.StatusOK, + expectedResult: &api.User{ + Id: "restricted-user", + Role: "user", + Status: "active", + IsBlocked: false, + IsCurrent: ptr(true), + IsServiceUser: ptr(false), + AutoGroups: []string{}, + Issued: ptr("api"), + LastLogin: ptr(time.Time{}), + Permissions: &api.UserPermissions{ + IsRestricted: true, + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)), + }, + }, }, } @@ -603,10 +686,42 @@ func TestCurrentUser(t *testing.T) { res := rr.Result() defer res.Body.Close() - if status := rr.Code; status != tc.expectedStatus { - t.Fatalf("handler returned wrong status code: got %v want %v", - status, tc.expectedStatus) + assert.Equal(t, tc.expectedStatus, rr.Code, "handler returned wrong status code") + + if tc.expectedResult != nil { + var result api.User + require.NoError(t, json.NewDecoder(res.Body).Decode(&result)) + assert.EqualValues(t, *tc.expectedResult, result) } }) } } + +func ptr[T any, PT *T](x T) PT { + return &x +} + +func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { + permissions := roles.Permissions{} + + for k := range modules.All { + if rolePermissions, ok := role.Permissions[k]; ok { + permissions[k] = rolePermissions + continue + } + permissions[k] = role.AutoAllowNew + } + + return permissions +} + +func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool { + modules := make(map[string]map[string]bool) + for module, operations := range permissions { + modules[string(module)] = make(map[string]bool) + for op, val := range operations { + modules[string(module)][string(op)] = val + } + } + return modules +} diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index a8e6790a9..f2732fbf8 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -15,16 +15,20 @@ import ( "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error +type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - authManager auth.Manager - ensureAccount EnsureAccountFunc - syncUserJWTGroups SyncUserJWTGroupsFunc + authManager auth.Manager + ensureAccount EnsureAccountFunc + getUserFromUserAuth GetUserFromUserAuthFunc + syncUserJWTGroups SyncUserJWTGroupsFunc } // NewAuthMiddleware instance constructor @@ -32,11 +36,13 @@ func NewAuthMiddleware( authManager auth.Manager, ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, + getUserFromUserAuth GetUserFromUserAuthFunc, ) *AuthMiddleware { return &AuthMiddleware{ - authManager: authManager, - ensureAccount: ensureAccount, - syncUserJWTGroups: syncUserJWTGroups, + authManager: authManager, + ensureAccount: ensureAccount, + syncUserJWTGroups: syncUserJWTGroups, + getUserFromUserAuth: getUserFromUserAuth, } } @@ -123,6 +129,12 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err) } + _, err = m.getUserFromUserAuth(ctx, userAuth) + if err != nil { + log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err) + return r, err + } + return nbcontext.SetUserAuthInRequest(r, userAuth), nil } @@ -155,6 +167,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h IsPAT: true, } + if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { + userAuth.AccountId = impersonate[0] + userAuth.IsChild = ok + } + return nbcontext.SetUserAuthInRequest(r, userAuth), nil } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 3dc7d51cb..2285ed244 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -190,6 +190,9 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) error { return nil }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -239,14 +242,15 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { }, }, { - name: "Valid PAT Token ignores child", + name: "Valid PAT Token accesses child", path: "/test?account=xyz", authHeader: "Token " + PAT, expectedUserAuth: &nbcontext.UserAuth{ - AccountId: accountID, + AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, DomainCategory: testAccount.DomainCategory, + IsChild: true, IsPAT: true, }, }, @@ -291,6 +295,9 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) error { return nil }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, ) for _, tc := range tt { diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index d7abbad47..c8a852e0a 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -352,3 +352,24 @@ func MigrateNewField[T any](ctx context.Context, db *gorm.DB, columnName string, log.WithContext(ctx).Infof("Migration of empty %s to default value in table %s completed", columnName, tableName) return nil } + +func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error { + var model T + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model) + return nil + } + + if !db.Migrator().HasIndex(&model, indexName) { + log.WithContext(ctx).Debugf("index %s does not exist in table %T, no migration needed", indexName, model) + return nil + } + + if err := db.Migrator().DropIndex(&model, indexName); err != nil { + return fmt.Errorf("failed to drop index %s: %w", indexName, err) + } + + log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model) + return nil +} diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index e907d6853..94377930a 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -227,3 +227,25 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing. assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") } + +func TestDropIndex(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&types.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&types.SetupKey{ + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + exist := db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id") + assert.True(t, exist, "Should have the index") + + err = migration.DropIndex[types.SetupKey](context.Background(), db, "idx_setup_keys_account_id") + require.NoError(t, err, "Migration should not fail to remove index") + + exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id") + assert.False(t, exist, "Should not have the index") +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 870fe3219..0dd3f927e 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" ) @@ -83,7 +84,7 @@ type MockAccountManager struct { CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) DeleteAccountFunc func(ctx context.Context, accountID, userID string) error - GetDNSDomainFunc func() string + GetDNSDomainFunc func(settings *types.Settings) string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) @@ -115,7 +116,8 @@ type MockAccountManager struct { CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error) UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) - GetCurrentUserInfoFunc func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) + GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { @@ -619,9 +621,9 @@ func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID, n } // GetDNSDomain mocks GetDNSDomain of the AccountManager interface -func (am *MockAccountManager) GetDNSDomain() string { +func (am *MockAccountManager) GetDNSDomain(settings *types.Settings) string { if am.GetDNSDomainFunc != nil { - return am.GetDNSDomainFunc() + return am.GetDNSDomainFunc(settings) } return "" } @@ -803,6 +805,14 @@ func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID stri return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented") } +// GetAccountByID mocks GetAccountByID of the AccountManager interface +func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { + if am.GetAccountMetaFunc != nil { + return am.GetAccountMetaFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented") +} + // GetUserByID mocks GetUserByID of the AccountManager interface func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { @@ -873,9 +883,9 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented") } -func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) { +func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { if am.GetCurrentUserInfoFunc != nil { - return am.GetCurrentUserInfoFunc(ctx, accountID, userID) + return am.GetCurrentUserInfoFunc(ctx, userAuth) } return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index ecac0a724..04c63608d 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -30,7 +30,7 @@ func (p NetworkResourceType) String() string { } type NetworkResource struct { - ID string `gorm:"index"` + ID string `gorm:"primaryKey"` NetworkID string `gorm:"index"` AccountID string `gorm:"index"` Name string diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index 5158ebb12..71465868f 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -10,7 +10,7 @@ import ( ) type NetworkRouter struct { - ID string `gorm:"index"` + ID string `gorm:"primaryKey"` NetworkID string `gorm:"index"` AccountID string `gorm:"index"` Peer string diff --git a/management/server/networks/types/network.go b/management/server/networks/types/network.go index a4ba7b821..d1c7f2b33 100644 --- a/management/server/networks/types/network.go +++ b/management/server/networks/types/network.go @@ -7,7 +7,7 @@ import ( ) type Network struct { - ID string `gorm:"index"` + ID string `gorm:"primaryKey"` AccountID string `gorm:"index"` Name string Description string diff --git a/management/server/peer.go b/management/server/peer.go index 27825a148..9ff80442e 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -49,20 +49,9 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - peers := make([]*nbpeer.Peer, 0) - peersMap := make(map[string]*nbpeer.Peer) - - for _, peer := range accountPeers { - if user.IsRegularUser() && user.Id != peer.UserID { - // only display peers that belong to the current user if the current user is not an admin - continue - } - peers = append(peers, peer) - peersMap[peer.ID] = peer - } - + // @note if the user has permission to read peers it shows all account peers if allowed { - return peers, nil + return accountPeers, nil } settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) @@ -70,10 +59,22 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, fmt.Errorf("failed to get account settings: %w", err) } - if settings.RegularUsersViewBlocked { + if user.IsRestrictable() && settings.RegularUsersViewBlocked { return []*nbpeer.Peer{}, nil } + // @note if it does not have permission read peers then only display it's own peers + peers := make([]*nbpeer.Peer, 0) + peersMap := make(map[string]*nbpeer.Peer) + + for _, peer := range accountPeers { + if user.Id != peer.UserID { + continue + } + peers = append(peers, peer) + peersMap[peer.ID] = peer + } + return am.getUserAccessiblePeers(ctx, accountID, peersMap, peers) } @@ -206,6 +207,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var sshChanged bool var loginExpirationChanged bool var inactivityExpirationChanged bool + var dnsDomain string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, update.ID) @@ -223,7 +225,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) + dnsDomain = am.GetDNSDomain(settings) + + update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) if err != nil { return err } @@ -276,11 +280,11 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if !peer.SSHEnabled { event = activity.PeerSSHDisabled } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) } if peerLabelChanged { - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(dnsDomain)) } if loginExpirationChanged { @@ -288,7 +292,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if !peer.LoginExpirationEnabled { event = activity.PeerLoginExpirationDisabled } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { am.checkAndSchedulePeerLoginExpiration(ctx, accountID) @@ -300,7 +304,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if !peer.InactivityExpirationEnabled { event = activity.PeerInactivityExpirationDisabled } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) @@ -413,7 +417,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin if err != nil { return nil, err } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) if err != nil { @@ -574,8 +578,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s ExtraDNSLabels: peer.ExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels, } + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) + opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) if !addedByUser { opEvent.Meta["setup_key_name"] = setupKeyName } @@ -591,10 +600,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return fmt.Errorf("failed to get account settings: %w", err) - } newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) @@ -1024,7 +1029,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) if err != nil { @@ -1060,7 +1065,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact log.WithContext(ctx).Debugf("failed to update user last login: %v", err) } - am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + + am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings))) return nil } @@ -1174,7 +1184,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account semaphore := make(chan struct{}, 10) dnsCache := &DNSConfigCache{} - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + dnsDomain := am.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() @@ -1215,7 +1226,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1270,7 +1281,8 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI } dnsCache := &DNSConfigCache{} - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + dnsDomain := am.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() @@ -1299,7 +1311,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) } @@ -1484,6 +1496,12 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { var peerDeletedEvents []func() + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + dnsDomain := am.GetDNSDomain(settings) + for _, peer := range peers { if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { return nil, err @@ -1514,7 +1532,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) }) } diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 50a44eb0f..ebbce5d4a 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -20,6 +20,8 @@ type Manager interface { ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error + + GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) } type managerImpl struct { @@ -96,3 +98,22 @@ func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID strin } return nil } + +func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) { + roleMap, ok := roles.RolesMap[role] + if !ok { + return roles.Permissions{}, status.NewUserRoleNotFoundError(string(role)) + } + + permissions := roles.Permissions{} + + for k := range modules.All { + if rolePermissions, ok := roleMap.Permissions[k]; ok { + permissions[k] = rolePermissions + continue + } + permissions[k] = roleMap.AutoAllowNew + } + + return permissions, nil +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index 266a24270..fa115d628 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -38,6 +38,21 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } +// GetPermissionsByRole mocks base method. +func (m *MockManager) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPermissionsByRole", ctx, role) + ret0, _ := ret[0].(roles.Permissions) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPermissionsByRole indicates an expected call of GetPermissionsByRole. +func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() diff --git a/management/server/permissions/modules/module.go b/management/server/permissions/modules/module.go index 4c42b6190..3d021a235 100644 --- a/management/server/permissions/modules/module.go +++ b/management/server/permissions/modules/module.go @@ -17,3 +17,19 @@ const ( SetupKeys Module = "setup_keys" Pats Module = "pats" ) + +var All = map[Module]struct{}{ + Networks: {}, + Peers: {}, + Groups: {}, + Settings: {}, + Accounts: {}, + Dns: {}, + Nameservers: {}, + Events: {}, + Policies: {}, + Routes: {}, + Users: {}, + SetupKeys: {}, + Pats: {}, +} diff --git a/management/server/permissions/roles/auditor.go b/management/server/permissions/roles/auditor.go new file mode 100644 index 000000000..33d8651f4 --- /dev/null +++ b/management/server/permissions/roles/auditor.go @@ -0,0 +1,16 @@ +package roles + +import ( + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/types" +) + +var Auditor = RolePermissions{ + Role: types.UserRoleAuditor, + AutoAllowNew: map[operations.Operation]bool{ + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, +} diff --git a/management/server/permissions/roles/network_admin.go b/management/server/permissions/roles/network_admin.go new file mode 100644 index 000000000..e95d58381 --- /dev/null +++ b/management/server/permissions/roles/network_admin.go @@ -0,0 +1,97 @@ +package roles + +import ( + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/types" +) + +var NetworkAdmin = RolePermissions{ + Role: types.UserRoleNetworkAdmin, + AutoAllowNew: map[operations.Operation]bool{ + operations.Read: false, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + Permissions: Permissions{ + modules.Networks: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Groups: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Settings: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.Accounts: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.Dns: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Nameservers: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Events: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.Policies: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Routes: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Users: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.SetupKeys: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.Pats: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Peers: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + }, +} diff --git a/management/server/permissions/roles/role_permissions.go b/management/server/permissions/roles/role_permissions.go index dda7e6b99..754e568f5 100644 --- a/management/server/permissions/roles/role_permissions.go +++ b/management/server/permissions/roles/role_permissions.go @@ -15,7 +15,9 @@ type RolePermissions struct { type Permissions map[modules.Module]map[operations.Operation]bool var RolesMap = map[types.UserRole]RolePermissions{ - types.UserRoleOwner: Owner, - types.UserRoleAdmin: Admin, - types.UserRoleUser: User, + types.UserRoleOwner: Owner, + types.UserRoleAdmin: Admin, + types.UserRoleUser: User, + types.UserRoleAuditor: Auditor, + types.UserRoleNetworkAdmin: NetworkAdmin, } diff --git a/management/server/route.go b/management/server/route.go index 8b91e127a..02755a708 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -398,7 +398,9 @@ func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.Ro Protocol: getProtoProtocol(rule.Protocol), PortInfo: getProtoPortInfo(rule), IsDynamic: rule.IsDynamic, + Domains: rule.Domains.ToPunycodeList(), PolicyID: []byte(rule.PolicyID), + RouteID: string(rule.RouteID), } } diff --git a/management/server/route_test.go b/management/server/route_test.go index dcda3e6d1..833477b55 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1850,6 +1850,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Destination: "192.168.0.0/16", Protocol: "all", Port: 80, + RouteID: "route1:peerA", }, { SourceRanges: []string{ @@ -1861,6 +1862,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Destination: "192.168.0.0/16", Protocol: "all", Port: 320, + RouteID: "route1:peerA", }, } additionalFirewallRule := []*types.RouteFirewallRule{ @@ -1872,6 +1874,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Destination: "192.168.10.0/16", Protocol: "tcp", Port: 80, + RouteID: "route4:peerA", }, { SourceRanges: []string{ @@ -1880,6 +1883,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Action: "accept", Destination: "192.168.10.0/16", Protocol: "all", + RouteID: "route4:peerA", }, } @@ -1888,6 +1892,9 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) + for _, rule := range expectedRoutesFirewallRules { + rule.RouteID = "route1:peerD" + } assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 @@ -1901,6 +1908,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Destination: existingNetwork.String(), Protocol: "tcp", PortRange: types.RulePortRange{Start: 80, End: 350}, + RouteID: "route2", }, { SourceRanges: []string{"0.0.0.0/0"}, @@ -1909,6 +1917,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Protocol: "all", Domains: domain.List{"example.com"}, IsDynamic: true, + RouteID: "route3", }, { SourceRanges: []string{"::/0"}, @@ -1917,6 +1926,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Protocol: "all", Domains: domain.List{"example.com"}, IsDynamic: true, + RouteID: "route3", }, } assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) @@ -2676,6 +2686,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Destination: "192.168.0.0/16", Protocol: "all", Port: 80, + RouteID: "resource2:peerA", }, { SourceRanges: []string{ @@ -2687,6 +2698,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Destination: "192.168.0.0/16", Protocol: "all", Port: 320, + RouteID: "resource2:peerA", }, } @@ -2701,6 +2713,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Port: 80, Domains: domain.List{"example.com"}, IsDynamic: true, + RouteID: "resource4:peerA", }, { SourceRanges: []string{ @@ -2711,6 +2724,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Protocol: "all", Domains: domain.List{"example.com"}, IsDynamic: true, + RouteID: "resource4:peerA", }, } assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...))) @@ -2719,6 +2733,9 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap) firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap) assert.Len(t, firewallRules, 2) + for _, rule := range expectedFirewallRules { + rule.RouteID = "resource2:peerD" + } assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) assert.Len(t, sourcePeers, 3) @@ -2736,6 +2753,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Destination: "10.10.10.0/24", Protocol: "tcp", PortRange: types.RulePortRange{Start: 80, End: 350}, + RouteID: "resource1:peerE", }, } assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) @@ -2758,6 +2776,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Destination: "10.12.12.1/32", Protocol: "tcp", Port: 8080, + RouteID: "resource5:peerL", }, } assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index aacb56ab8..d568460f9 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -658,6 +658,21 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { return all } +func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) { + var accountMeta types.AccountMeta + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + First(&accountMeta, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account meta %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + return &accountMeta, nil +} + func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { @@ -785,6 +800,19 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( return s.GetAccount(ctx, peer.AccountID) } +func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) { + var account types.Account + result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account) + if result.Error != nil { + return "", status.NewGetAccountFromStoreError(result.Error) + } + if result.RowsAffected == 0 { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return account.Id, nil +} + func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string @@ -1655,18 +1683,26 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, } func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&types.Policy{}, accountAndIDQueryCondition, accountID, policyID) - if err := result.Error; err != nil { - log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) - return status.Errorf(status.Internal, "failed to delete policy from store") - } + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { + return fmt.Errorf("delete policy rules: %w", err) + } - if result.RowsAffected == 0 { - return status.NewPolicyNotFoundError(policyID) - } + result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where(accountAndIDQueryCondition, accountID, policyID). + Delete(&types.Policy{}) - return nil + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) + return status.Errorf(status.Internal, "failed to delete policy from store") + } + + if result.RowsAffected == 0 { + return status.NewPolicyNotFoundError(policyID) + } + + return nil + }) } // GetAccountPostureChecks retrieves posture checks for an account. diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 589e727e9..8e99b34e1 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -60,10 +60,10 @@ func Test_NewStore(t *testing.T) { runTestForAllEngines(t, "", func(t *testing.T, store Store) { if store == nil { - t.Errorf("expected to create a new Store") + t.Fatalf("expected to create a new Store") } if len(store.GetAllAccounts(context.Background())) != 0 { - t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") + t.Fatalf("expected to create a new empty Accounts map when creating a new FileStore") } }) } @@ -1115,7 +1115,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { group := &types.Group{ ID: "group-id", - AccountID: "account-id", + AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", Name: "group-name", Issued: "api", Peers: nil, @@ -3247,3 +3247,44 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { require.NoError(t, err) require.Equal(t, 8003, len(accountGroups)) } + +func TestSqlStore_GetAccountMeta(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + accountMeta, err := store.GetAccountMeta(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.NotNil(t, accountMeta) + require.Equal(t, accountID, accountMeta.AccountID) + require.Equal(t, "edafee4e-63fb-11ec-90d6-0242ac120003", accountMeta.CreatedBy) + require.Equal(t, "test.com", accountMeta.Domain) + require.Equal(t, "private", accountMeta.DomainCategory) + require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC()) +} + +func TestSqlStore_GetAnyAccountID(t *testing.T) { + t.Run("should return account ID when accounts exist", func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID, err := store.GetAnyAccountID(context.Background()) + require.NoError(t, err) + assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", accountID) + }) + + t.Run("should return error when no accounts exist", func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID, err := store.GetAnyAccountID(context.Background()) + require.Error(t, err) + sErr, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, sErr.Type(), status.NotFound) + assert.Empty(t, accountID) + }) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index c13a8dfe6..6da623956 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -50,10 +50,12 @@ type Store interface { GetAccountsCounter(ctx context.Context) (int64, error) GetAllAccounts(ctx context.Context) []*types.Account GetAccount(ctx context.Context, accountID string) (*types.Account, error) + GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) + GetAnyAccountID(ctx context.Context) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) @@ -313,6 +315,15 @@ func getMigrations(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.MigrateNewField[routerTypes.NetworkRouter](ctx, db, "enabled", true) }, + func(db *gorm.DB) error { + return migration.DropIndex[networkTypes.Network](ctx, db, "idx_networks_id") + }, + func(db *gorm.DB) error { + return migration.DropIndex[resourceTypes.NetworkResource](ctx, db, "idx_network_resources_id") + }, + func(db *gorm.DB) error { + return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id") + }, } } diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 2859e82c8..7900dabf5 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -25,7 +25,7 @@ CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); -INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:01:38.210000+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cfefqs706sqkneg59g2g"]',0,0); INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["abcd"]',0,0); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,NULL,'2024-10-02 16:01:38.210678+02:00','api',0,''); diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index 8672efa7f..ca022bfef 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -12,6 +12,7 @@ import ( "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/modules/postgres" + testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/testcontainers/testcontainers-go/wait" ) @@ -84,3 +85,28 @@ func CreatePostgresTestContainer() (func(), error) { return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn) } + +// CreateRedisTestContainer creates a new Redis container for testing. +func CreateRedisTestContainer() (func(), string, error) { + ctx := context.Background() + + redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7")) + if err != nil { + return nil, "", err + } + + cleanup := func() { + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) + defer cancelFunc() + if err = redisContainer.Terminate(timeoutCtx); err != nil { + log.WithContext(ctx).Warnf("failed to stop redis container %s: %s", redisContainer.GetContainerID(), err) + } + } + + redisURL, err := redisContainer.ConnectionString(ctx) + if err != nil { + return nil, "", err + } + + return cleanup, redisURL, nil +} diff --git a/management/server/testutil/store_ios.go b/management/server/testutil/store_ios.go index edde62f1e..a614258d2 100644 --- a/management/server/testutil/store_ios.go +++ b/management/server/testutil/store_ios.go @@ -14,3 +14,9 @@ func CreateMysqlTestContainer() (func(), error) { // Empty function for MySQL }, nil } + +func CreateRedisTestContainer() (func(), string, error) { + return func() { + // Empty function for Redis + }, "", nil +} diff --git a/management/server/types/account.go b/management/server/types/account.go index 687709991..8315f5796 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -40,6 +40,17 @@ const ( type LookupMap map[string]struct{} +// AccountMeta is a struct that contains a stripped down version of the Account object. +// It doesn't carry any peers, groups, policies, or routes, etc. Just some metadata (e.g. ID, created by, created at, etc). +type AccountMeta struct { + // AccountId is the unique identifier of the account + AccountID string `gorm:"column:id"` + CreatedAt time.Time + CreatedBy string + Domain string + DomainCategory string +} + // Account represents a unique account of the system type Account struct { // we have to name column to aid as it collides with Network.Id when work with associations @@ -855,6 +866,16 @@ func (a *Account) Copy() *Account { } } +func (a *Account) GetMeta() *AccountMeta { + return &AccountMeta{ + AccountID: a.Id, + CreatedBy: a.CreatedBy, + CreatedAt: a.CreatedAt, + Domain: a.Domain, + DomainCategory: a.DomainCategory, + } +} + func (a *Account) GetGroupAll() (*Group, error) { for _, g := range a.Groups { if g.Name == "All" { @@ -1219,6 +1240,7 @@ func getDefaultPermit(route *route.Route) []*RouteFirewallRule { Protocol: string(PolicyRuleProtocolALL), Domains: route.Domains, IsDynamic: route.IsDynamic(), + RouteID: route.ID, } rules = append(rules, &rule) @@ -1267,7 +1289,7 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer if route.Peer != peer.Key { continue } - resourceAppliedPolicies := resourcePolicies[route.GetResourceID()] + resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())] distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go index d98a56871..ef54abea2 100644 --- a/management/server/types/firewall_rule.go +++ b/management/server/types/firewall_rule.go @@ -62,6 +62,7 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule baseRule := RouteFirewallRule{ PolicyID: rule.PolicyID, + RouteID: route.ID, SourceRanges: sourceRanges, Action: string(rule.Action), Destination: route.Network.String(), diff --git a/management/server/types/group.go b/management/server/types/group.go index 00a28fa77..1b321387c 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -14,7 +14,7 @@ const ( // Group of the peers for ACL type Group struct { // ID of the group - ID string + ID string `gorm:"primaryKey"` // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go index 5b752bc36..c09c64a3d 100644 --- a/management/server/types/route_firewall_rule.go +++ b/management/server/types/route_firewall_rule.go @@ -2,6 +2,7 @@ package types import ( "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) // RouteFirewallRule a firewall rule applicable for a routed network. @@ -9,6 +10,9 @@ type RouteFirewallRule struct { // PolicyID is the ID of the policy this rule is derived from PolicyID string + // RouteID is the ID of the route this rule belongs to. + RouteID route.ID + // SourceRanges IP ranges of the routing peers. SourceRanges []string diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 7054ede8c..c8de2a98c 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -39,6 +39,9 @@ type Settings struct { // RoutingPeerDNSResolutionEnabled enabled the DNS resolution on the routing peers RoutingPeerDNSResolutionEnabled bool + // DNSDomain is the custom domain for that account + DNSDomain string + // Extra is a dictionary of Account settings Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` } @@ -58,6 +61,7 @@ func (s *Settings) Copy() *Settings { PeerInactivityExpiration: s.PeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled, + DNSDomain: s.DNSDomain, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/management/server/types/user.go b/management/server/types/user.go index 5f7a4f2cb..783fe14da 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -15,6 +15,8 @@ const ( UserRoleUser UserRole = "user" UserRoleUnknown UserRole = "unknown" UserRoleBillingAdmin UserRole = "billing_admin" + UserRoleAuditor UserRole = "auditor" + UserRoleNetworkAdmin UserRole = "network_admin" UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" @@ -35,6 +37,10 @@ func StrRoleToUserRole(strRole string) UserRole { return UserRoleUser case "billing_admin": return UserRoleBillingAdmin + case "auditor": + return UserRoleAuditor + case "network_admin": + return UserRoleNetworkAdmin default: return UserRoleUnknown } @@ -59,11 +65,6 @@ type UserInfo struct { LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` IntegrationReference integration_reference.IntegrationReference `json:"-"` - Permissions UserPermissions `json:"permissions"` -} - -type UserPermissions struct { - DashboardView string `json:"dashboard_view"` } // User represents a user of the system @@ -126,21 +127,18 @@ func (u *User) IsRegularUser() bool { return !u.HasAdminPower() && !u.IsServiceUser } +// IsRestrictable checks whether a user is in a restrictable role. +func (u *User) IsRestrictable() bool { + return u.Role == UserRoleUser || u.Role == UserRoleBillingAdmin +} + // ToUserInfo converts a User object to a UserInfo object. -func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { +func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { autoGroups := u.AutoGroups if autoGroups == nil { autoGroups = []string{} } - dashboardViewPermissions := "full" - if !u.HasAdminPower() { - dashboardViewPermissions = "limited" - if settings.RegularUsersViewBlocked { - dashboardViewPermissions = "blocked" - } - } - if userData == nil { return &UserInfo{ ID: u.Id, @@ -153,9 +151,6 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo IsBlocked: u.Blocked, LastLogin: u.GetLastLogin(), Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, }, nil } if userData.ID != u.Id { @@ -178,9 +173,6 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo IsBlocked: u.Blocked, LastLogin: u.GetLastLogin(), Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, }, nil } diff --git a/management/server/user.go b/management/server/user.go index 9ec16e72c..44ad3b68f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -19,6 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" ) @@ -122,11 +124,6 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u CreatedAt: time.Now().UTC(), } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { return nil, err } @@ -138,7 +135,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil) - return newUser.ToUserInfo(idpUser, settings) + return newUser.ToUserInfo(idpUser) } // createNewIdpUser validates the invite and creates a new user in the IdP @@ -360,6 +357,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, err } + // @note this is essential to prevent non admin users with Pats create permission frpm creating one for a service user if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return nil, status.NewAdminPermissionError() } @@ -727,19 +725,14 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi // If the AccountManager has a non-nil idpManager and the User is not a service user, // it will attempt to look up the UserData from the cache. func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.User, accountID string) (*types.UserInfo, error) { - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - if !isNil(am.idpManager) && !user.IsServiceUser { userData, err := am.lookupUserInCache(ctx, user.Id, accountID) if err != nil { return nil, err } - return user.ToUserInfo(userData, settings) + return user.ToUserInfo(userData) } - return user.ToUserInfo(nil, settings) + return user.ToUserInfo(nil) } // validateUserUpdate validates the update operation for a user. @@ -879,17 +872,12 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a queriedUsers = append(queriedUsers, usersFromIntegration...) } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - userInfosMap := make(map[string]*types.UserInfo) // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo if len(queriedUsers) == 0 { for _, accountUser := range accountUsers { - info, err := accountUser.ToUserInfo(nil, settings) + info, err := accountUser.ToUserInfo(nil) if err != nil { return nil, err } @@ -902,7 +890,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a for _, localUser := range accountUsers { var info *types.UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { - info, err = localUser.ToUserInfo(queriedUser, settings) + info, err = localUser.ToUserInfo(queriedUser) if err != nil { return nil, err } @@ -912,14 +900,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a name = localUser.ServiceUserName } - dashboardViewPermissions := "full" - if !localUser.HasAdminPower() { - dashboardViewPermissions = "limited" - if settings.RegularUsersViewBlocked { - dashboardViewPermissions = "blocked" - } - } - info = &types.UserInfo{ ID: localUser.Id, Email: "", @@ -929,7 +909,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a Status: string(types.UserStatusActive), IsServiceUser: localUser.IsServiceUser, NonDeletable: localUser.NonDeletable, - Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions}, } } userInfosMap[info.ID] = info @@ -940,6 +919,12 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a // expireAndUpdatePeers expires all peers of the given user and updates them in the account func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return err + } + dnsDomain := am.GetDNSDomain(settings) + var peerIDs []string for _, peer := range peers { // nolint:staticcheck @@ -957,7 +942,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou am.StoreEvent( ctx, peer.UserID, peer.ID, accountID, - activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), + activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) } @@ -1233,8 +1218,10 @@ func validateUserInvite(invite *types.UserInfo) error { return nil } -// GetCurrentUserInfo retrieves the account's current user info -func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) { +// GetCurrentUserInfo retrieves the account's current user info and permissions +func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { + accountID, userID := userAuth.AccountId, userAuth.UserId + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err @@ -1252,10 +1239,25 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, account return nil, err } + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + userInfo, err := am.getUserInfo(ctx, user, accountID) if err != nil { return nil, err } - return userInfo, nil + userWithPermissions := &users.UserInfoWithPermissions{ + UserInfo: userInfo, + Restricted: !userAuth.IsChild && user.IsRestrictable() && settings.RegularUsersViewBlocked, + } + + permissions, err := am.permissionsManager.GetPermissionsByRole(ctx, user.Role) + if err == nil { + userWithPermissions.Permissions = permissions + } + + return userWithPermissions, nil } diff --git a/management/server/user_test.go b/management/server/user_test.go index 83c5ac49a..66bdc1683 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -13,7 +13,10 @@ import ( nbcache "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -1020,90 +1023,6 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { assert.Equal(t, 2, regular) } -func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { - testCases := []struct { - name string - role types.UserRole - limitedViewSettings bool - expectedDashboardPermissions string - }{ - { - name: "Regular user, no limited view settings", - role: types.UserRoleUser, - limitedViewSettings: false, - expectedDashboardPermissions: "limited", - }, - { - name: "Admin user, no limited view settings", - role: types.UserRoleAdmin, - limitedViewSettings: false, - expectedDashboardPermissions: "full", - }, - { - name: "Owner, no limited view settings", - role: types.UserRoleOwner, - limitedViewSettings: false, - expectedDashboardPermissions: "full", - }, - { - name: "Regular user, limited view settings", - role: types.UserRoleUser, - limitedViewSettings: true, - expectedDashboardPermissions: "blocked", - }, - { - name: "Admin user, limited view settings", - role: types.UserRoleAdmin, - limitedViewSettings: true, - expectedDashboardPermissions: "full", - }, - { - name: "Owner, limited view settings", - role: types.UserRoleOwner, - limitedViewSettings: true, - expectedDashboardPermissions: "full", - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - if err != nil { - t.Fatalf("Error when creating store: %s", err) - } - t.Cleanup(cleanup) - - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI) - account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings - delete(account.Users, mockUserID) - - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } - - permissionsManager := permissions.NewManager(store) - am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, - permissionsManager: permissionsManager, - } - - users, err := am.ListUsers(context.Background(), mockAccountID) - if err != nil { - t.Fatalf("Error when checking user role: %s", err) - } - - assert.Equal(t, 1, len(users)) - - userInfo, _ := users[0].ToUserInfo(nil, account.Settings) - assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) - }) - } - -} - func TestDefaultAccountManager_ExternalCache(t *testing.T) { store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) if err != nil { @@ -1654,121 +1573,154 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { tt := []struct { name string - accountId string - userId string + userAuth nbcontext.UserAuth expectedErr error - expectedResult *types.UserInfo + expectedResult *users.UserInfoWithPermissions }{ { name: "not found", - accountId: account1.Id, - userId: "not-found", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"}, expectedErr: status.NewUserNotFoundError("not-found"), }, { name: "not part of account", - accountId: account1.Id, - userId: "account2Owner", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, expectedErr: status.NewUserNotPartOfAccountError(), }, { name: "blocked", - accountId: account1.Id, - userId: "blocked-user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, expectedErr: status.NewUserBlockedError(), }, { name: "service user", - accountId: account1.Id, - userId: "service-user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"}, expectedErr: status.NewPermissionDeniedError(), }, { - name: "owner user", - accountId: account1.Id, - userId: "account1Owner", - expectedResult: &types.UserInfo{ - ID: "account1Owner", - Name: "", - Role: "owner", - AutoGroups: []string{}, - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - IntegrationReference: integration_reference.IntegrationReference{}, - Permissions: types.UserPermissions{ - DashboardView: "full", + name: "owner user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "account1Owner", + Name: "", + Role: "owner", + AutoGroups: []string{}, + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, }, + Permissions: mergeRolePermissions(roles.Owner), }, }, { - name: "regular user", - accountId: account1.Id, - userId: "regular-user", - expectedResult: &types.UserInfo{ - ID: "regular-user", - Name: "", - Role: "user", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - IntegrationReference: integration_reference.IntegrationReference{}, - Permissions: types.UserPermissions{ - DashboardView: "limited", + name: "regular user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "regular-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, }, + Permissions: mergeRolePermissions(roles.User), }, }, { - name: "admin user", - accountId: account1.Id, - userId: "admin-user", - expectedResult: &types.UserInfo{ - ID: "admin-user", - Name: "", - Role: "admin", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - IntegrationReference: integration_reference.IntegrationReference{}, - Permissions: types.UserPermissions{ - DashboardView: "full", + name: "admin user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "admin-user", + Name: "", + Role: "admin", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, }, + Permissions: mergeRolePermissions(roles.Admin), }, }, { - name: "settings blocked regular user", - accountId: account2.Id, - userId: "settings-blocked-user", - expectedResult: &types.UserInfo{ - ID: "settings-blocked-user", - Name: "", - Role: "user", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - IntegrationReference: integration_reference.IntegrationReference{}, - Permissions: types.UserPermissions{ - DashboardView: "blocked", + name: "settings blocked regular user", + userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "settings-blocked-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, }, + Permissions: mergeRolePermissions(roles.User), + Restricted: true, + }, + }, + + { + name: "settings blocked regular user child account", + userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "settings-blocked-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, + }, + Permissions: mergeRolePermissions(roles.User), + Restricted: false, + }, + }, + { + name: "settings blocked owner user", + userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "account2Owner", + Name: "", + Role: "owner", + AutoGroups: []string{}, + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, + }, + Permissions: mergeRolePermissions(roles.Owner), }, }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - result, err := am.GetCurrentUserInfo(context.Background(), tc.accountId, tc.userId) + result, err := am.GetCurrentUserInfo(context.Background(), tc.userAuth) if tc.expectedErr != nil { assert.Equal(t, err, tc.expectedErr) @@ -1780,3 +1732,17 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }) } } + +func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { + permissions := roles.Permissions{} + + for k := range modules.All { + if rolePermissions, ok := role.Permissions[k]; ok { + permissions[k] = rolePermissions + continue + } + permissions[k] = role.AutoAllowNew + } + + return permissions +} diff --git a/management/server/users/user.go b/management/server/users/user.go new file mode 100644 index 000000000..2f2788271 --- /dev/null +++ b/management/server/users/user.go @@ -0,0 +1,14 @@ +package users + +import ( + "github.com/netbirdio/netbird/management/server/permissions/roles" + "github.com/netbirdio/netbird/management/server/types" +) + +// Wrapped UserInfo with Role Permissions +type UserInfoWithPermissions struct { + *types.UserInfo + + Permissions roles.Permissions + Restricted bool +} diff --git a/relay/client/dialer/quic/quic.go b/relay/client/dialer/quic/quic.go index 7fd486f87..3fd48fb19 100644 --- a/relay/client/dialer/quic/quic.go +++ b/relay/client/dialer/quic/quic.go @@ -28,6 +28,16 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { return nil, err } + // Get the base TLS config + tlsClientConfig := quictls.ClientQUICTLSConfig() + + // Set ServerName to hostname if not an IP address + host, _, splitErr := net.SplitHostPort(quicURL) + if splitErr == nil && net.ParseIP(host) == nil { + // It's a hostname, not an IP - modify directly + tlsClientConfig.ServerName = host + } + quicConfig := &quic.Config{ KeepAlivePeriod: 30 * time.Second, MaxIdleTimeout: 4 * time.Minute, @@ -47,7 +57,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { return nil, err } - session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig) + session, err := quic.Dial(ctx, udpConn, udpAddr, tlsClientConfig, quicConfig) if err != nil { if errors.Is(err, context.Canceled) { return nil, err @@ -61,12 +71,29 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { } func prepareURL(address string) (string, error) { - if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") { + var host string + var defaultPort string + + switch { + case strings.HasPrefix(address, "rels://"): + host = address[7:] + defaultPort = "443" + case strings.HasPrefix(address, "rel://"): + host = address[6:] + defaultPort = "80" + default: return "", fmt.Errorf("unsupported scheme: %s", address) } - if strings.HasPrefix(address, "rels://") { - return address[7:], nil + finalHost, finalPort, err := net.SplitHostPort(host) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + return host + ":" + defaultPort, nil + } + + // return any other split error as is + return "", err } - return address[6:], nil + + return finalHost + ":" + finalPort, nil } diff --git a/release_files/install.sh b/release_files/install.sh index e5a61dcfe..da5c613d5 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -199,6 +199,21 @@ install_native_binaries() { fi } +# Handle macOS .pkg installer +install_pkg() { + case "$(uname -m)" in + x86_64) ARCH="amd64" ;; + arm64|aarch64) ARCH="arm64" ;; + *) echo "Unsupported macOS arch: $(uname -m)" >&2; exit 1 ;; + esac + + PKG_URL=$(curl -sIL -o /dev/null -w '%{url_effective}' "https://pkgs.netbird.io/macos/${ARCH}") + echo "Downloading NetBird macOS installer from https://pkgs.netbird.io/macos/${ARCH}" + curl -fsSL -o /tmp/netbird.pkg "${PKG_URL}" + ${SUDO} installer -pkg /tmp/netbird.pkg -target / + rm -f /tmp/netbird.pkg +} + check_use_bin_variable() { if [ "${USE_BIN_INSTALL}-x" = "true-x" ]; then echo "The installation will be performed using binary files" @@ -209,16 +224,22 @@ check_use_bin_variable() { install_netbird() { if [ -x "$(command -v netbird)" ]; then - status_output=$(netbird status) - if echo "$status_output" | grep -q 'Management: Connected' && echo "$status_output" | grep -q 'Signal: Connected'; then - echo "NetBird service is running, please stop it before proceeding" - exit 1 - fi + status_output="$(netbird status 2>&1 || true)" - if [ -n "$status_output" ]; then - echo "NetBird seems to be installed already, please remove it before proceeding" - exit 1 - fi + if echo "$status_output" | grep -q 'failed to connect to daemon error: context deadline exceeded'; then + echo "Warning: could not reach NetBird daemon (timeout), proceeding anyway" + else + if echo "$status_output" | grep -q 'Management: Connected' && \ + echo "$status_output" | grep -q 'Signal: Connected'; then + echo "NetBird service is running, please stop it before proceeding" + exit 1 + fi + + if [ -n "$status_output" ]; then + echo "NetBird seems to be installed already, please remove it before proceeding" + exit 1 + fi + fi fi # Run the installation, if a desktop environment is not detected @@ -265,6 +286,16 @@ install_netbird() { ${SUDO} pacman -Syy add_aur_repo ;; + pkg) + # Check if the package is already installed + if [ -f /Library/Receipts/netbird.pkg ]; then + echo "NetBird is already installed. Please remove it before proceeding." + exit 1 + fi + + # Install the package + install_pkg + ;; brew) # Remove Netbird if it had been installed using Homebrew before if brew ls --versions netbird >/dev/null 2>&1; then @@ -274,7 +305,7 @@ install_netbird() { netbird service stop netbird service uninstall - # Unlik the app + # Unlink the app brew unlink netbird fi @@ -312,7 +343,7 @@ install_netbird() { echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null # Load and start netbird service - if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then + if [ "$PACKAGE_MANAGER" != "rpm-ostree" ] && [ "$PACKAGE_MANAGER" != "pkg" ]; then if ! ${SUDO} netbird service install 2>&1; then echo "NetBird service has already been loaded" fi @@ -451,9 +482,8 @@ if type uname >/dev/null 2>&1; then # Check the availability of a compatible package manager if check_use_bin_variable; then PACKAGE_MANAGER="bin" - elif [ -x "$(command -v brew)" ]; then - PACKAGE_MANAGER="brew" - echo "The installation will be performed using brew package manager" + else + PACKAGE_MANAGER="pkg" fi ;; esac @@ -471,4 +501,4 @@ case "$UPDATE_FLAG" in ;; *) install_netbird -esac +esac \ No newline at end of file diff --git a/route/hauniqueid.go b/route/hauniqueid.go index 4d952beba..064608171 100644 --- a/route/hauniqueid.go +++ b/route/hauniqueid.go @@ -4,13 +4,14 @@ import "strings" const haSeparator = "|" +// HAUniqueID is a unique identifier that is used to group high availability routes. type HAUniqueID string func (id HAUniqueID) String() string { return string(id) } -// NetID returns the Network ID from the HAUniqueID +// NetID returns the NetID from the HAUniqueID func (id HAUniqueID) NetID() NetID { if i := strings.LastIndex(string(id), haSeparator); i != -1 { return NetID(id[:i]) diff --git a/route/route.go b/route/route.go index f7bf3ea87..722dacc2d 100644 --- a/route/route.go +++ b/route/route.go @@ -6,8 +6,6 @@ import ( "slices" "strings" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/status" ) @@ -46,10 +44,16 @@ const ( DomainNetwork ) +// ID is the unique route ID. type ID string +// ResID is the resourceID part of a route.ID (first part before the colon). +type ResID string + +// NetID is the route network identifier, a human-readable string. type NetID string +// HAMap is a map of HAUniqueID to a list of routes. type HAMap map[HAUniqueID][]*Route // NetworkType route network type @@ -162,21 +166,25 @@ func (r *Route) IsDynamic() bool { return r.NetworkType == DomainNetwork } +// GetHAUniqueID returns the HAUniqueID for the route, it can be used for grouping. func (r *Route) GetHAUniqueID() HAUniqueID { - if r.IsDynamic() { - domains, err := r.Domains.String() - if err != nil { - log.Errorf("Failed to convert domains to string: %v", err) - domains = r.Domains.PunycodeString() - } - return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, domains)) - } - return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.Network.String())) + return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.NetString())) } -// GetResourceID returns the Networks Resource ID from a route ID -func (r *Route) GetResourceID() string { - return strings.Split(string(r.ID), ":")[0] +// GetResourceID returns the Networks ResID from the route ID. +// It's the part before the first colon in the ID string. +func (r *Route) GetResourceID() ResID { + return ResID(strings.Split(string(r.ID), ":")[0]) +} + +// NetString returns the network string. +// If the route is dynamic, it returns the domains as comma-separated punycode-encoded string. +// If the route is not dynamic, it returns the network (prefix) string. +func (r *Route) NetString() string { + if r.IsDynamic() { + return r.Domains.SafeString() + } + return r.Network.String() } // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6 diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 74ac6c163..1c22e7869 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -234,7 +234,7 @@ func (s *SharedSocket) read(receiver receiver) { } // ReadFrom reads packets received in the packetDemux channel -func (s *SharedSocket) ReadFrom(b []byte) (n int, addr net.Addr, err error) { +func (s *SharedSocket) ReadFrom(b []byte) (int, net.Addr, error) { var pkt rcvdPacket select { case <-s.ctx.Done(): @@ -263,8 +263,7 @@ func (s *SharedSocket) ReadFrom(b []byte) (n int, addr net.Addr, err error) { decodedLayers := make([]gopacket.LayerType, 0, 3) - err = parser.DecodeLayers(pkt.buf, &decodedLayers) - if err != nil { + if err := parser.DecodeLayers(pkt.buf, &decodedLayers); err != nil { return 0, nil, err } @@ -273,8 +272,8 @@ func (s *SharedSocket) ReadFrom(b []byte) (n int, addr net.Addr, err error) { Port: int(udp.SrcPort), } - copy(b, payload) - return int(udp.Length), remoteAddr, nil + n := copy(b, payload) + return n, remoteAddr, nil } // WriteTo builds a UDP packet and writes it using the specific IP version writer diff --git a/upload-server/Dockerfile b/upload-server/Dockerfile new file mode 100644 index 000000000..a38c6fbb8 --- /dev/null +++ b/upload-server/Dockerfile @@ -0,0 +1,3 @@ +FROM gcr.io/distroless/base:debug +ENTRYPOINT [ "/go/bin/netbird-upload" ] +COPY netbird-upload /go/bin/netbird-upload diff --git a/upload-server/main.go b/upload-server/main.go new file mode 100644 index 000000000..dcfb35cdf --- /dev/null +++ b/upload-server/main.go @@ -0,0 +1,22 @@ +package main + +import ( + "errors" + "log" + "net/http" + + "github.com/netbirdio/netbird/upload-server/server" + "github.com/netbirdio/netbird/util" +) + +func main() { + err := util.InitLog("info", "console") + if err != nil { + log.Fatalf("Failed to initialize logger: %v", err) + } + + srv := server.NewServer() + if err = srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("Failed to start server: %v", err) + } +} diff --git a/upload-server/server/local.go b/upload-server/server/local.go new file mode 100644 index 000000000..f12c472d2 --- /dev/null +++ b/upload-server/server/local.go @@ -0,0 +1,124 @@ +package server + +import ( + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/upload-server/types" +) + +const ( + defaultDir = "/var/lib/netbird" + putHandler = "/{dir}/{file}" +) + +type local struct { + url string + dir string +} + +func configureLocalHandlers(mux *http.ServeMux) error { + envURL, ok := os.LookupEnv("SERVER_URL") + if !ok { + return fmt.Errorf("SERVER_URL environment variable is required") + } + _, err := url.Parse(envURL) + if err != nil { + return fmt.Errorf("SERVER_URL environment variable is invalid: %w", err) + } + + dir := defaultDir + envDir, ok := os.LookupEnv("STORE_DIR") + if ok { + if !filepath.IsAbs(envDir) { + return fmt.Errorf("STORE_DIR environment variable should point to an absolute path, e.g. /tmp") + } + log.Infof("Using local directory: %s", envDir) + dir = envDir + } + + l := &local{ + url: envURL, + dir: dir, + } + mux.HandleFunc(types.GetURLPath, l.handlerGetUploadURL) + mux.HandleFunc(putURLPath+putHandler, l.handlePutRequest) + + return nil +} + +func (l *local) handlerGetUploadURL(w http.ResponseWriter, r *http.Request) { + if !isValidRequest(w, r) { + return + } + + objectKey := getObjectKey(w, r) + if objectKey == "" { + return + } + + uploadURL, err := l.getUploadURL(objectKey) + if err != nil { + http.Error(w, "failed to get upload URL", http.StatusInternalServerError) + log.Errorf("Failed to get upload URL: %v", err) + return + } + + respondGetRequest(w, uploadURL, objectKey) +} + +func (l *local) getUploadURL(objectKey string) (string, error) { + parsedUploadURL, err := url.Parse(l.url) + if err != nil { + return "", fmt.Errorf("failed to parse upload URL: %w", err) + } + newURL := parsedUploadURL.JoinPath(parsedUploadURL.Path, putURLPath, objectKey) + return newURL.String(), nil +} + +func (l *local) handlePutRequest(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("failed to read body: %v", err), http.StatusInternalServerError) + return + } + + uploadDir := r.PathValue("dir") + if uploadDir == "" { + http.Error(w, "missing dir path", http.StatusBadRequest) + return + } + uploadFile := r.PathValue("file") + if uploadFile == "" { + http.Error(w, "missing file name", http.StatusBadRequest) + return + } + + dirPath := filepath.Join(l.dir, uploadDir) + err = os.MkdirAll(dirPath, 0750) + if err != nil { + http.Error(w, "failed to create upload dir", http.StatusInternalServerError) + log.Errorf("Failed to create upload dir: %v", err) + return + } + + file := filepath.Join(dirPath, uploadFile) + if err := os.WriteFile(file, body, 0600); err != nil { + http.Error(w, "failed to write file", http.StatusInternalServerError) + log.Errorf("Failed to write file %s: %v", file, err) + return + } + log.Infof("Uploading file %s", file) + w.WriteHeader(http.StatusOK) +} diff --git a/upload-server/server/local_test.go b/upload-server/server/local_test.go new file mode 100644 index 000000000..bd8a87809 --- /dev/null +++ b/upload-server/server/local_test.go @@ -0,0 +1,65 @@ +package server + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/upload-server/types" +) + +func Test_LocalHandlerGetUploadURL(t *testing.T) { + mockURL := "http://localhost:8080" + t.Setenv("SERVER_URL", mockURL) + t.Setenv("STORE_DIR", t.TempDir()) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, types.GetURLPath+"?id=test-file", nil) + req.Header.Set(types.ClientHeader, types.ClientHeaderValue) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response types.GetURLResponse + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + require.Contains(t, response.URL, "test-file/") + require.NotEmpty(t, response.Key) + require.Contains(t, response.Key, "test-file/") + +} + +func Test_LocalHandlePutRequest(t *testing.T) { + mockDir := t.TempDir() + mockURL := "http://localhost:8080" + t.Setenv("SERVER_URL", mockURL) + t.Setenv("STORE_DIR", mockDir) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + fileContent := []byte("test file content") + req := httptest.NewRequest(http.MethodPut, putURLPath+"/uploads/test.txt", bytes.NewReader(fileContent)) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + expectedFilePath := filepath.Join(mockDir, "uploads", "test.txt") + createdFileContent, err := os.ReadFile(expectedFilePath) + require.NoError(t, err) + require.Equal(t, fileContent, createdFileContent) +} diff --git a/upload-server/server/s3.go b/upload-server/server/s3.go new file mode 100644 index 000000000..c0976acb5 --- /dev/null +++ b/upload-server/server/s3.go @@ -0,0 +1,69 @@ +package server + +import ( + "context" + "fmt" + "net/http" + "os" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/upload-server/types" +) + +type sThree struct { + ctx context.Context + bucket string + presignClient *s3.PresignClient +} + +func configureS3Handlers(mux *http.ServeMux) error { + bucket := os.Getenv(bucketVar) + region, ok := os.LookupEnv("AWS_REGION") + if !ok { + return fmt.Errorf("AWS_REGION environment variable is required") + } + ctx := context.Background() + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return fmt.Errorf("unable to load SDK config: %w", err) + } + + client := s3.NewFromConfig(cfg) + + handler := &sThree{ + ctx: ctx, + bucket: bucket, + presignClient: s3.NewPresignClient(client), + } + mux.HandleFunc(types.GetURLPath, handler.handlerGetUploadURL) + return nil +} + +func (s *sThree) handlerGetUploadURL(w http.ResponseWriter, r *http.Request) { + if !isValidRequest(w, r) { + return + } + + objectKey := getObjectKey(w, r) + if objectKey == "" { + return + } + + req, err := s.presignClient.PresignPutObject(s.ctx, &s3.PutObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(objectKey), + }, s3.WithPresignExpires(15*time.Minute)) + + if err != nil { + http.Error(w, "failed to presign URL", http.StatusInternalServerError) + log.Errorf("Presign error: %v", err) + return + } + + respondGetRequest(w, req.URL, objectKey) +} diff --git a/upload-server/server/s3_test.go b/upload-server/server/s3_test.go new file mode 100644 index 000000000..26b0ecd09 --- /dev/null +++ b/upload-server/server/s3_test.go @@ -0,0 +1,103 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "runtime" + "testing" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" + + "github.com/netbirdio/netbird/upload-server/types" +) + +func Test_S3HandlerGetUploadURL(t *testing.T) { + if runtime.GOOS != "linux" && os.Getenv("CI") == "true" { + t.Skip("Skipping test on non-Linux and CI environment due to docker dependency") + } + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows due to potential docker dependency") + } + + awsEndpoint := "http://127.0.0.1:4566" + awsRegion := "us-east-1" + + ctx := context.Background() + containerRequest := testcontainers.ContainerRequest{ + Image: "localstack/localstack:s3-latest", + ExposedPorts: []string{"4566:4566/tcp"}, + WaitingFor: wait.ForLog("Ready"), + } + + c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: containerRequest, + Started: true, + }) + if err != nil { + t.Error(err) + } + defer func(c testcontainers.Container, ctx context.Context) { + if err := c.Terminate(ctx); err != nil { + t.Log(err) + } + }(c, ctx) + + t.Setenv("AWS_REGION", awsRegion) + t.Setenv("AWS_ENDPOINT_URL", awsEndpoint) + t.Setenv("AWS_ACCESS_KEY_ID", "test") + t.Setenv("AWS_SECRET_ACCESS_KEY", "test") + + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion), config.WithBaseEndpoint(awsEndpoint)) + if err != nil { + t.Error(err) + } + + client := s3.NewFromConfig(cfg, func(o *s3.Options) { + o.UsePathStyle = true + o.BaseEndpoint = cfg.BaseEndpoint + }) + + bucketName := "test" + if _, err := client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: &bucketName, + }); err != nil { + t.Error(err) + } + + list, err := client.ListBuckets(ctx, &s3.ListBucketsInput{}) + if err != nil { + t.Error(err) + } + + assert.Equal(t, len(list.Buckets), 1) + assert.Equal(t, *list.Buckets[0].Name, bucketName) + + t.Setenv(bucketVar, bucketName) + + mux := http.NewServeMux() + err = configureS3Handlers(mux) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, types.GetURLPath+"?id=test-file", nil) + req.Header.Set(types.ClientHeader, types.ClientHeaderValue) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response types.GetURLResponse + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + require.Contains(t, response.URL, "test-file/") + require.NotEmpty(t, response.Key) + require.Contains(t, response.Key, "test-file/") +} diff --git a/upload-server/server/server.go b/upload-server/server/server.go new file mode 100644 index 000000000..29ef72732 --- /dev/null +++ b/upload-server/server/server.go @@ -0,0 +1,109 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "os" + "time" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/upload-server/types" +) + +const ( + putURLPath = "/upload" + bucketVar = "BUCKET" +) + +type Server struct { + srv *http.Server +} + +func NewServer() *Server { + address := os.Getenv("SERVER_ADDRESS") + if address == "" { + log.Infof("SERVER_ADDRESS environment variable was not set, using 0.0.0.0:8080") + address = "0.0.0.0:8080" + } + mux := http.NewServeMux() + err := configureMux(mux) + if err != nil { + log.Fatalf("Failed to configure server: %v", err) + } + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + }) + + return &Server{ + srv: &http.Server{Addr: address, Handler: mux}, + } +} + +func (s *Server) Start() error { + log.Infof("Starting upload server on %s", s.srv.Addr) + return s.srv.ListenAndServe() +} + +func (s *Server) Stop() error { + if s.srv != nil { + log.Infof("Stopping upload server on %s", s.srv.Addr) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.srv.Shutdown(ctx) + } + return nil +} + +func configureMux(mux *http.ServeMux) error { + _, ok := os.LookupEnv(bucketVar) + if ok { + return configureS3Handlers(mux) + } else { + return configureLocalHandlers(mux) + } +} + +func getObjectKey(w http.ResponseWriter, r *http.Request) string { + id := r.URL.Query().Get("id") + if id == "" { + http.Error(w, "id query param required", http.StatusBadRequest) + return "" + } + + return id + "/" + uuid.New().String() +} + +func isValidRequest(w http.ResponseWriter, r *http.Request) bool { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return false + } + + if r.Header.Get(types.ClientHeader) != types.ClientHeaderValue { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return false + } + return true +} +func respondGetRequest(w http.ResponseWriter, uploadURL string, objectKey string) { + response := types.GetURLResponse{ + URL: uploadURL, + Key: objectKey, + } + + rdata, err := json.Marshal(response) + if err != nil { + http.Error(w, "failed to marshal response", http.StatusInternalServerError) + log.Errorf("Marshal error: %v", err) + return + } + + w.WriteHeader(http.StatusOK) + _, err = w.Write(rdata) + if err != nil { + log.Errorf("Write error: %v", err) + } +} diff --git a/upload-server/types/upload.go b/upload-server/types/upload.go new file mode 100644 index 000000000..327c28e75 --- /dev/null +++ b/upload-server/types/upload.go @@ -0,0 +1,18 @@ +package types + +const ( + // ClientHeader is the header used to identify the client + ClientHeader = "x-nb-client" + // ClientHeaderValue is the value of the ClientHeader + ClientHeaderValue = "netbird" + // GetURLPath is the path for the GetURL request + GetURLPath = "/upload-url" + + DefaultBundleURL = "https://upload.debug.netbird.io" + GetURLPath +) + +// GetURLResponse is the response for the GetURL request +type GetURLResponse struct { + URL string + Key string +}