diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index d7007c860..2b4c43cb4 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -14,7 +14,7 @@ jobs: test: strategy: matrix: - store: ['jsonfile', 'sqlite'] + store: ['sqlite'] runs-on: macos-latest steps: - name: Install Go diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml new file mode 100644 index 000000000..15fc6a729 --- /dev/null +++ b/.github/workflows/golang-test-freebsd.yml @@ -0,0 +1,39 @@ + +name: Test Code FreeBSD + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} + cancel-in-progress: true + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Test in FreeBSD + id: test + uses: vmactions/freebsd-vm@v1 + with: + usesh: true + prepare: | + pkg install -y curl + pkg install -y git + + run: | + set -x + curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L + tar zxf go.tar.gz + mv go /usr/local/go + ln -s /usr/local/go/bin/go /usr/local/bin/go + go mod tidy + go test -timeout 5m -p 1 ./iface/... + go test -timeout 5m -p 1 ./client/... + cd client + go build . + cd .. \ No newline at end of file diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 4259f1b3e..120b213e9 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: arch: [ '386','amd64' ] - store: [ 'jsonfile', 'sqlite', 'postgres'] + store: [ 'sqlite', 'postgres'] runs-on: ubuntu-latest steps: - name: Install Go @@ -86,7 +86,10 @@ jobs: run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - name: Generate RouteManager Test bin - run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/... + run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager + + - name: Generate SystemOps Test bin + run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops - name: Generate nftables Manager Test bin run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... @@ -108,6 +111,9 @@ jobs: - name: Run RouteManager tests in docker run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 + - name: Run SystemOps tests in docker + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1 + - name: Run nftables Manager tests in docker run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e117e8fab..65ae0aa26 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -173,7 +173,7 @@ jobs: retention-days: 3 release_ui_darwin: - runs-on: macos-11 + runs-on: macos-latest steps: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index ee9739e09..abdd18ceb 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -178,34 +178,79 @@ jobs: - name: Checkout code uses: actions/checkout@v3 - - name: run script + - name: run script with Zitadel PostgreSQL run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh - - name: test Caddy file gen + - name: test Caddy file gen postgres run: test -f Caddyfile - - name: test docker-compose file gen + + - name: test docker-compose file gen postgres run: test -f docker-compose.yml - - name: test management.json file gen + + - name: test management.json file gen postgres run: test -f management.json - - name: test turnserver.conf file gen + + - name: test turnserver.conf file gen postgres run: | set -x test -f turnserver.conf grep external-ip turnserver.conf - - name: test zitadel.env file gen + + - name: test zitadel.env file gen postgres run: test -f zitadel.env - - name: test dashboard.env file gen + + - name: test dashboard.env file gen postgres run: test -f dashboard.env + + - name: test zdb.env file gen postgres + run: test -f zdb.env + + - name: Postgres run cleanup + run: | + docker-compose down --volumes --rmi all + rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env + + - name: run script with Zitadel CockroachDB + run: bash -x infrastructure_files/getting-started-with-zitadel.sh + env: + NETBIRD_DOMAIN: use-ip + ZITADEL_DATABASE: cockroach + + - name: test Caddy file gen CockroachDB + run: test -f Caddyfile + + - name: test docker-compose file gen CockroachDB + run: test -f docker-compose.yml + + - name: test management.json file gen CockroachDB + run: test -f management.json + + - name: test turnserver.conf file gen CockroachDB + run: | + set -x + test -f turnserver.conf + grep external-ip turnserver.conf + + - name: test zitadel.env file gen CockroachDB + run: test -f zitadel.env + + - name: test dashboard.env file gen CockroachDB + run: test -f dashboard.env + test-download-geolite2-script: runs-on: ubuntu-latest steps: - name: Install jq run: sudo apt-get update && sudo apt-get install -y unzip sqlite3 + - name: Checkout code uses: actions/checkout@v3 + - name: test script run: bash -x infrastructure_files/download-geolite2.sh + - name: test mmdb file exists run: test -f GeoLite2-City.mmdb + - name: test geonames file exists run: test -f geonames.db diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml index cde09cda1..2c3afa91b 100644 --- a/.goreleaser_ui_darwin.yaml +++ b/.goreleaser_ui_darwin.yaml @@ -3,8 +3,10 @@ builds: - id: netbird-ui-darwin dir: client/ui binary: netbird-ui - env: [CGO_ENABLED=1] - + env: + - CGO_ENABLED=1 + - MACOSX_DEPLOYMENT_TARGET=11.0 + - MACOS_DEPLOYMENT_TARGET=11.0 goos: - darwin goarch: diff --git a/client/Dockerfile b/client/Dockerfile index 7f4060f3d..a3220bf33 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.5 +FROM alpine:3.19 RUN apk add --no-cache ca-certificates iptables ip6tables ENV NB_FOREGROUND_MODE=true ENTRYPOINT [ "/usr/local/bin/netbird","up"] diff --git a/client/cmd/root.go b/client/cmd/root.go index 839380712..f0b5d2bdf 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -36,6 +36,7 @@ const ( disableAutoConnectFlag = "disable-auto-connect" serverSSHAllowedFlag = "allow-server-ssh" extraIFaceBlackListFlag = "extra-iface-blacklist" + dnsRouteIntervalFlag = "dns-router-interval" ) var ( @@ -68,7 +69,9 @@ var ( autoConnectDisabled bool extraIFaceBlackList []string anonymizeFlag bool - rootCmd = &cobra.Command{ + dnsRouteInterval time.Duration + + rootCmd = &cobra.Command{ Use: "netbird", Short: "", Long: "", diff --git a/client/cmd/route.go b/client/cmd/route.go index d92e079ad..c8881822b 100644 --- a/client/cmd/route.go +++ b/client/cmd/route.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "strings" "github.com/spf13/cobra" "google.golang.org/grpc/status" @@ -66,18 +67,60 @@ func routesList(cmd *cobra.Command, _ []string) error { return nil } - cmd.Println("Available Routes:") - for _, route := range resp.Routes { - selectedStatus := "Not Selected" - if route.GetSelected() { - selectedStatus = "Selected" - } - cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus) - } + printRoutes(cmd, resp) return nil } +func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) { + cmd.Println("Available Routes:") + for _, route := range resp.Routes { + printRoute(cmd, route) + } +} + +func printRoute(cmd *cobra.Command, route *proto.Route) { + selectedStatus := getSelectedStatus(route) + domains := route.GetDomains() + + if len(domains) > 0 { + printDomainRoute(cmd, route, domains, selectedStatus) + } else { + printNetworkRoute(cmd, route, selectedStatus) + } +} + +func getSelectedStatus(route *proto.Route) string { + if route.GetSelected() { + return "Selected" + } + return "Not Selected" +} + +func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus) + resolvedIPs := route.GetResolvedIPs() + + if len(resolvedIPs) > 0 { + printResolvedIPs(cmd, domains, resolvedIPs) + } else { + cmd.Printf(" Resolved IPs: -\n") + } +} + +func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus) +} + +func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) { + cmd.Printf(" Resolved IPs:\n") + for _, domain := range domains { + if ipList, exists := resolvedIPs[domain]; exists { + cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", ")) + } + } +} + func routesSelect(cmd *cobra.Command, args []string) error { conn, err := getClient(cmd) if err != nil { diff --git a/client/cmd/status.go b/client/cmd/status.go index 3dacfbe4f..e6c7b8be8 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -807,11 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) { } for i, route := range peer.Routes { - prefix, err := netip.ParsePrefix(route) - if err == nil { - ip := a.AnonymizeIPString(prefix.Addr().String()) - peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits()) - } + peer.Routes[i] = anonymizeRoute(a, route) } } @@ -847,12 +843,21 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) } for i, route := range overview.Routes { - prefix, err := netip.ParsePrefix(route) - if err == nil { - ip := a.AnonymizeIPString(prefix.Addr().String()) - overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits()) - } + overview.Routes[i] = anonymizeRoute(a, route) } overview.FQDN = a.AnonymizeDomain(overview.FQDN) } + +func anonymizeRoute(a *anonymize.Anonymizer, route string) string { + prefix, err := netip.ParsePrefix(route) + if err == nil { + ip := a.AnonymizeIPString(prefix.Addr().String()) + return fmt.Sprintf("%s/%d", ip, prefix.Bits()) + } + domains := strings.Split(route, ", ") + for i, domain := range domains { + domains[i] = a.AnonymizeDomain(domain) + } + return strings.Join(domains, ", ") +} diff --git a/client/cmd/testutil.go b/client/cmd/testutil_test.go similarity index 79% rename from client/cmd/testutil.go rename to client/cmd/testutil_test.go index f032884df..63d90cc63 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil_test.go @@ -7,6 +7,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/util" @@ -53,7 +56,10 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - sigProto.RegisterSignalExchangeServer(s, sig.NewServer()) + srv, err := sig.NewServer(otel.Meter("")) + require.NoError(t, err) + + sigProto.RegisterSignalExchangeServer(s, srv) go func() { if err := s.Serve(lis); err != nil { panic(err) @@ -70,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir) + store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { t.Fatal(err) } @@ -81,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste if err != nil { return nil, nil } - iv, _ := integrations.NewIntegratedValidator(eventStore) - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) + iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) if err != nil { t.Fatal(err) } turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "") - mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { t.Fatal(err) } @@ -102,7 +108,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste } func startClientDaemon( - t *testing.T, ctx context.Context, managementURL, configPath string, + t *testing.T, ctx context.Context, _, configPath string, ) (*grpc.Server, net.Listener) { t.Helper() lis, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/client/cmd/up.go b/client/cmd/up.go index a5bbc58be..f69e9eb27 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -7,11 +7,13 @@ import ( "net/netip" "runtime" "strings" + "time" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" @@ -40,8 +42,12 @@ func init() { upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") - upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring") + upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, + `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+ + `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, + ) upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") + upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval") } func upFunc(cmd *cobra.Command, args []string) error { @@ -137,6 +143,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { } } + if cmd.Flag(dnsRouteIntervalFlag).Changed { + ic.DNSRouteInterval = &dnsRouteInterval + } + config, err := internal.UpdateOrCreateConfig(ic) if err != nil { return fmt.Errorf("get config file: %v", err) @@ -237,6 +247,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { loginRequest.NetworkMonitor = &networkMonitor } + if cmd.Flag(dnsRouteIntervalFlag).Changed { + loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval) + } + var loginErr error var loginResp *proto.LoginResponse diff --git a/client/errors/errors.go b/client/errors/errors.go new file mode 100644 index 000000000..cef999ac8 --- /dev/null +++ b/client/errors/errors.go @@ -0,0 +1,30 @@ +package errors + +import ( + "fmt" + "strings" + + "github.com/hashicorp/go-multierror" +) + +func formatError(es []error) string { + if len(es) == 0 { + return fmt.Sprintf("0 error occurred:\n\t* %s", es[0]) + } + + points := make([]string, len(es)) + for i, err := range es { + points[i] = fmt.Sprintf("* %s", err) + } + + return fmt.Sprintf( + "%d errors occurred:\n\t%s", + len(es), strings.Join(points, "\n\t")) +} + +func FormatErrorOrNil(err *multierror.Error) error { + if err != nil { + err.ErrorFormat = formatError + } + return err.ErrorOrNil() +} diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 40e1077be..e8f09a106 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error { return nil } - err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) + err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) if err != nil { return err } - err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) + err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) if err != nil { return err } @@ -101,6 +101,7 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, } delete(i.rules, ruleKey) } + err = i.iptablesClient.Insert(table, chain, 1, rule...) if err != nil { return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) @@ -317,6 +318,13 @@ func (i *routerManager) createChain(table, newChain string) error { return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) } + // Add the loopback return rule to the NAT chain + loopbackRule := []string{"-o", "lo", "-j", "RETURN"} + err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...) + if err != nil { + return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err) + } + err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") if err != nil { return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) @@ -326,6 +334,30 @@ func (i *routerManager) createChain(table, newChain string) error { return nil } +// addNATRule appends an iptables rule pair to the nat chain +func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(keyFormat, pair.ID) + rule := genRuleSpec(jump, pair.Source, pair.Destination) + existingRule, found := i.rules[ruleKey] + if found { + err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) + } + delete(i.rules, ruleKey) + } + + // inserting after loopback ignore rule + err := i.iptablesClient.Insert(table, chain, 2, rule...) + if err != nil { + return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) + } + + i.rules[ruleKey] = rule + + return nil +} + // genRuleSpec generates rule specification func genRuleSpec(jump, source, destination string) []string { return []string{"-s", source, "-d", destination, "-j", jump} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 8395fc270..a376c98c3 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.InsertRoutingRules(pair) + return m.router.AddRoutingRules(pair) } func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { diff --git a/client/firewall/nftables/route_linux.go b/client/firewall/nftables/route_linux.go index 381136e50..71d5ac88e 100644 --- a/client/firewall/nftables/route_linux.go +++ b/client/firewall/nftables/route_linux.go @@ -22,6 +22,8 @@ const ( userDataAcceptForwardRuleSrc = "frwacceptsrc" userDataAcceptForwardRuleDst = "frwacceptdst" + + loopbackInterface = "lo\x00" ) // some presets for building nftable rules @@ -126,6 +128,22 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeNAT, }) + // Add RETURN rule for loopback interface + loRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(loopbackInterface), + }, + &expr.Verdict{Kind: expr.VerdictReturn}, + }, + } + r.conn.InsertRule(loRule) + err := r.refreshRulesMap() if err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) @@ -138,28 +156,28 @@ func (r *router) createContainers() error { return nil } -// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain -func (r *router) InsertRoutingRules(pair manager.RouterPair) error { +// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain +func (r *router) AddRoutingRules(pair manager.RouterPair) error { err := r.refreshRulesMap() if err != nil { return err } - err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) + err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) if err != nil { return err } - err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) + err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) if err != nil { return err } if pair.Masquerade { - err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) + err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) if err != nil { return err } - err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) + err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) if err != nil { return err } @@ -177,8 +195,8 @@ func (r *router) InsertRoutingRules(pair manager.RouterPair) error { return nil } -// insertRoutingRule inserts a nftable rule to the conn client flush queue -func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { +// addRoutingRule inserts a nftable rule to the conn client flush queue +func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { sourceExp := generateCIDRMatcherExpressions(true, pair.Source) destExp := generateCIDRMatcherExpressions(false, pair.Destination) @@ -199,7 +217,7 @@ func (r *router) insertRoutingRule(format, chainName string, pair manager.Router } } - r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{ + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ Table: r.workTable, Chain: r.chains[chainName], Exprs: expression, diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index aa1224a5a..913fbd5d2 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { require.NoError(t, err, "shouldn't return error") - err = manager.InsertRoutingRules(testCase.InputPair) + err = manager.AddRoutingRules(testCase.InputPair) defer func() { _ = manager.RemoveRoutingRules(testCase.InputPair) }() diff --git a/client/internal/config.go b/client/internal/config.go index 66721cd21..461dcdd96 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -6,13 +6,16 @@ import ( "net/url" "os" "reflect" + "runtime" "strings" + "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/iface" mgm "github.com/netbirdio/netbird/management/client" @@ -53,6 +56,7 @@ type ConfigInput struct { NetworkMonitor *bool DisableAutoConnect *bool ExtraIFaceBlackList []string + DNSRouteInterval *time.Duration } // Config Configuration type @@ -64,7 +68,7 @@ type Config struct { AdminURL *url.URL WgIface string WgPort int - NetworkMonitor bool + NetworkMonitor *bool IFaceBlackList []string DisableIPv6Discovery bool RosenpassEnabled bool @@ -95,6 +99,9 @@ type Config struct { // DisableAutoConnect determines whether the client should not start with the service // it's set to false by default due to backwards compatibility DisableAutoConnect bool + + // DNSRouteInterval is the interval in which the DNS routes are updated + DNSRouteInterval time.Duration } // ReadConfig read config file and return with Config. If it is not exists create a new with default values @@ -304,12 +311,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.NetworkMonitor != nil && *input.NetworkMonitor != config.NetworkMonitor { + if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor { log.Infof("switching Network Monitor to %t", *input.NetworkMonitor) - config.NetworkMonitor = *input.NetworkMonitor + config.NetworkMonitor = input.NetworkMonitor updated = true } + if config.NetworkMonitor == nil { + // enable network monitoring by default on windows and darwin clients + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + enabled := true + config.NetworkMonitor = &enabled + updated = true + } + } + if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress { log.Infof("updating custom DNS address %#v (old value %#v)", string(input.CustomDNSAddress), config.CustomDNSAddress) @@ -357,6 +373,18 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } + if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval { + log.Infof("updating DNS route interval to %s (old value %s)", + input.DNSRouteInterval.String(), config.DNSRouteInterval.String()) + config.DNSRouteInterval = *input.DNSRouteInterval + updated = true + } else if config.DNSRouteInterval == 0 { + config.DNSRouteInterval = dynamic.DefaultInterval + log.Infof("using default DNS route interval %s", config.DNSRouteInterval) + updated = true + + } + return updated, nil } diff --git a/client/internal/connect.go b/client/internal/connect.go index 4d79ba72b..a20557f89 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -264,8 +264,10 @@ func (c *ConnectClient) run( return wrapErr(err) } + checks := loginResp.GetChecks() + c.engineMutex.Lock() - c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe) + c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks) c.engineMutex.Unlock() err = c.engine.Start() @@ -342,6 +344,10 @@ func (c *ConnectClient) Engine() *Engine { // createEngineConfig converts configuration received from Management Service to EngineConfig func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { + nm := false + if config.NetworkMonitor != nil { + nm = *config.NetworkMonitor + } engineConf := &EngineConfig{ WgIfaceName: config.WgIface, WgAddr: peerConfig.Address, @@ -349,13 +355,14 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe DisableIPv6Discovery: config.DisableIPv6Discovery, WgPrivateKey: key, WgPort: config.WgPort, - NetworkMonitor: config.NetworkMonitor, + NetworkMonitor: nm, SSHKey: []byte(config.SSHKey), NATExternalIPs: config.NATExternalIPs, CustomDNSAddress: config.CustomDNSAddress, RosenpassEnabled: config.RosenpassEnabled, RosenpassPermissive: config.RosenpassPermissive, ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), + DNSRouteInterval: config.DNSRouteInterval, } if config.PreSharedKey != "" { diff --git a/client/internal/dns/consts_freebsd.go b/client/internal/dns/consts_freebsd.go new file mode 100644 index 000000000..958eca8e5 --- /dev/null +++ b/client/internal/dns/consts_freebsd.go @@ -0,0 +1,6 @@ +package dns + +const ( + fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" + fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager" +) diff --git a/client/internal/dns/consts_linux.go b/client/internal/dns/consts_linux.go new file mode 100644 index 000000000..32456a50f --- /dev/null +++ b/client/internal/dns/consts_linux.go @@ -0,0 +1,8 @@ +//go:build !android + +package dns + +const ( + fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" + fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager" +) diff --git a/client/internal/dns/dbus_linux.go b/client/internal/dns/dbus_unix.go similarity index 96% rename from client/internal/dns/dbus_linux.go rename to client/internal/dns/dbus_unix.go index b2604e9fa..ba1c07fae 100644 --- a/client/internal/dns/dbus_linux.go +++ b/client/internal/dns/dbus_unix.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/file_parser_linux.go b/client/internal/dns/file_parser_unix.go similarity index 99% rename from client/internal/dns/file_parser_linux.go rename to client/internal/dns/file_parser_unix.go index 02f6d03a5..130c88214 100644 --- a/client/internal/dns/file_parser_linux.go +++ b/client/internal/dns/file_parser_unix.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/file_parser_linux_test.go b/client/internal/dns/file_parser_unix_test.go similarity index 99% rename from client/internal/dns/file_parser_linux_test.go rename to client/internal/dns/file_parser_unix_test.go index 4263d4063..1d6e64683 100644 --- a/client/internal/dns/file_parser_linux_test.go +++ b/client/internal/dns/file_parser_unix_test.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/file_repair_linux.go b/client/internal/dns/file_repair_unix.go similarity index 98% rename from client/internal/dns/file_repair_linux.go rename to client/internal/dns/file_repair_unix.go index cbdda5e9e..ae2c33b86 100644 --- a/client/internal/dns/file_repair_linux.go +++ b/client/internal/dns/file_repair_unix.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/file_repair_linux_test.go b/client/internal/dns/file_repair_unix_test.go similarity index 98% rename from client/internal/dns/file_repair_linux_test.go rename to client/internal/dns/file_repair_unix_test.go index 4e27f46ba..4dba79e99 100644 --- a/client/internal/dns/file_repair_linux_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/file_linux.go b/client/internal/dns/file_unix.go similarity index 99% rename from client/internal/dns/file_linux.go rename to client/internal/dns/file_unix.go index b9d6d699d..624e089cb 100644 --- a/client/internal/dns/file_linux.go +++ b/client/internal/dns/file_unix.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/file_linux_test.go b/client/internal/dns/file_unix_test.go similarity index 98% rename from client/internal/dns/file_linux_test.go rename to client/internal/dns/file_unix_test.go index 902791b36..46726536e 100644 --- a/client/internal/dns/file_linux_test.go +++ b/client/internal/dns/file_unix_test.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/host_linux.go b/client/internal/dns/host_unix.go similarity index 85% rename from client/internal/dns/host_linux.go rename to client/internal/dns/host_unix.go index cb246bcfe..72b8f6c6e 100644 --- a/client/internal/dns/host_linux.go +++ b/client/internal/dns/host_unix.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns @@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) { if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { return networkManager, nil } - if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) { + if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() { if checkStub() { return systemdManager, nil } else { @@ -116,16 +116,10 @@ func getOSDNSManagerType() (osManagerType, error) { } } if strings.Contains(text, "resolvconf") { - if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) { - var value string - err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value) - if err == nil { - if value == systemdDbusResolvConfModeForeign { - return systemdManager, nil - } - } - log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err) + if isSystemdResolveConfMode() { + return systemdManager, nil } + return resolvConfManager, nil } } diff --git a/client/internal/dns/network_manager_linux.go b/client/internal/dns/network_manager_unix.go similarity index 99% rename from client/internal/dns/network_manager_linux.go rename to client/internal/dns/network_manager_unix.go index dfd4cf4d3..184047a64 100644 --- a/client/internal/dns/network_manager_linux.go +++ b/client/internal/dns/network_manager_unix.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/resolvconf_linux.go b/client/internal/dns/resolvconf_unix.go similarity index 98% rename from client/internal/dns/resolvconf_linux.go rename to client/internal/dns/resolvconf_unix.go index 72db5faf1..0c17626c7 100644 --- a/client/internal/dns/resolvconf_linux.go +++ b/client/internal/dns/resolvconf_unix.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 22966d89c..6cbd9ea15 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -39,6 +39,10 @@ func (w *mocWGIface) Address() iface.WGAddress { } } +func (w *mocWGIface) ToInterface() *net.Interface { + panic("implement me") +} + func (w *mocWGIface) GetFilter() iface.PacketFilter { return w.filter } @@ -261,7 +265,7 @@ func TestUpdateDNSServer(t *testing.T) { if err != nil { t.Fatal(err) } - wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil) + wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -339,7 +343,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { } privKey, _ := wgtypes.GeneratePrivateKey() - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil) + wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) if err != nil { t.Errorf("build interface wireguard: %v", err) return @@ -797,7 +801,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { } privKey, _ := wgtypes.GeneratePrivateKey() - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil) + wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) if err != nil { t.Fatalf("build interface wireguard: %v", err) return nil, err diff --git a/client/internal/dns/server_linux.go b/client/internal/dns/server_unix.go similarity index 76% rename from client/internal/dns/server_linux.go rename to client/internal/dns/server_unix.go index aeb24b511..455425625 100644 --- a/client/internal/dns/server_linux.go +++ b/client/internal/dns/server_unix.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns diff --git a/client/internal/dns/systemd_freebsd.go b/client/internal/dns/systemd_freebsd.go new file mode 100644 index 000000000..0de805337 --- /dev/null +++ b/client/internal/dns/systemd_freebsd.go @@ -0,0 +1,20 @@ +package dns + +import ( + "errors" + "fmt" +) + +var errNotImplemented = errors.New("not implemented") + +func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { + return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented) +} + +func isSystemdResolvedRunning() bool { + return false +} + +func isSystemdResolveConfMode() bool { + return false +} diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 27a93fbe1..e2fa5b71a 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -242,3 +242,25 @@ func getSystemdDbusProperty(property string, store any) error { return v.Store(store) } + +func isSystemdResolvedRunning() bool { + return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) +} + +func isSystemdResolveConfMode() bool { + if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) { + return false + } + + var value string + if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil { + log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err) + return false + } + + if value == systemdDbusResolvConfModeForeign { + return true + } + + return false +} diff --git a/client/internal/dns/unclean_shutdown_linux.go b/client/internal/dns/unclean_shutdown_unix.go similarity index 94% rename from client/internal/dns/unclean_shutdown_linux.go rename to client/internal/dns/unclean_shutdown_unix.go index afd587720..8a32090c3 100644 --- a/client/internal/dns/unclean_shutdown_linux.go +++ b/client/internal/dns/unclean_shutdown_unix.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (linux && !android) || freebsd package dns @@ -14,11 +14,6 @@ import ( log "github.com/sirupsen/logrus" ) -const ( - fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" - fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager" -) - func CheckUncleanShutdown(wgIface string) error { if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { if errors.Is(err, fs.ErrNotExist) { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index e82c98fbc..b502bf5eb 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -78,6 +78,11 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { }() log.WithField("question", r.Question[0]).Trace("received an upstream question") + // set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records + if r.Extra == nil { + r.SetEdns0(4096, false) + r.MsgHdr.AuthenticatedData = true + } select { case <-u.ctx.Done(): diff --git a/client/internal/dns/wgiface.go b/client/internal/dns/wgiface.go index 2c34f1c47..2f08e8d52 100644 --- a/client/internal/dns/wgiface.go +++ b/client/internal/dns/wgiface.go @@ -2,12 +2,17 @@ package dns -import "github.com/netbirdio/netbird/iface" +import ( + "net" + + "github.com/netbirdio/netbird/iface" +) // WGIface defines subset methods of interface required for manager type WGIface interface { Name() string Address() iface.WGAddress + ToInterface() *net.Interface IsUserspaceBind() bool GetFilter() iface.PacketFilter GetDevice() *iface.DeviceWrapper diff --git a/client/internal/engine.go b/client/internal/engine.go index cb624bf43..ec513391a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -10,6 +10,7 @@ import ( "net/netip" "reflect" "runtime" + "slices" "strings" "sync" "sync/atomic" @@ -30,12 +31,15 @@ import ( "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/wgproxy" nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface/bind" mgm "github.com/netbirdio/netbird/management/client" + "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" auth "github.com/netbirdio/netbird/relay/auth/hmac" relayClient "github.com/netbirdio/netbird/relay/client" @@ -43,6 +47,7 @@ import ( signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/util" + nbnet "github.com/netbirdio/netbird/util/net" ) // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. @@ -93,6 +98,8 @@ type EngineConfig struct { RosenpassPermissive bool ServerSSHAllowed bool + + DNSRouteInterval time.Duration } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. @@ -105,8 +112,8 @@ type Engine struct { // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn - beforePeerHook peer.BeforeAddPeerHookFunc - afterPeerHook peer.AfterRemovePeerHookFunc + beforePeerHook nbnet.AddHookFunc + afterPeerHook nbnet.RemoveHookFunc // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -159,6 +166,9 @@ type Engine struct { relayProbe *Probe wgProbe *Probe + // checks are the client-applied posture checks that need to be evaluated on the client + checks []*mgmProto.Checks + relayManager *relayClient.Manager } @@ -178,6 +188,7 @@ func NewEngine( config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, + checks []*mgmProto.Checks, ) *Engine { return NewEngineWithProbes( clientCtx, @@ -192,6 +203,7 @@ func NewEngine( nil, nil, nil, + checks, ) } @@ -209,6 +221,7 @@ func NewEngineWithProbes( signalProbe *Probe, relayProbe *Probe, wgProbe *Probe, + checks []*mgmProto.Checks, ) *Engine { return &Engine{ clientCtx: clientCtx, @@ -230,6 +243,7 @@ func NewEngineWithProbes( signalProbe: signalProbe, relayProbe: relayProbe, wgProbe: wgProbe, + checks: checks, } } @@ -277,8 +291,6 @@ func (e *Engine) Start() error { } e.ctx, e.cancel = context.WithCancel(e.clientCtx) - e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort) - wgIface, err := e.newWgIface() if err != nil { log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err) @@ -286,6 +298,9 @@ func (e *Engine) Start() error { } e.wgInterface = wgIface + userspace := e.wgInterface.IsUserspaceBind() + e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort) + if e.config.RosenpassEnabled { log.Infof("rosenpass is enabled") if e.config.RosenpassPermissive { @@ -310,7 +325,7 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer - e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) + e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes) beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { log.Errorf("Failed to initialize route manager: %s", err) @@ -498,6 +513,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { // todo update signal } + if err := e.updateChecksIfNew(update.Checks); err != nil { + return err + } + if update.GetNetworkMap() != nil { // only apply new changes and ignore old ones err := e.updateNetworkMap(update.GetNetworkMap()) @@ -505,7 +524,27 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return err } } + return nil +} +// updateChecksIfNew updates checks if there are changes and sync new meta with management +func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { + // if checks are equal, we skip the update + if isChecksEqual(e.checks, checks) { + return nil + } + e.checks = checks + + info, err := system.GetInfoWithChecks(e.ctx, checks) + if err != nil { + log.Warnf("failed to get system info with checks: %v", err) + info = system.GetInfo(e.ctx) + } + + if err := e.mgmClient.SyncMeta(info); err != nil { + log.Errorf("could not sync meta: error %s", err) + return err + } return nil } @@ -521,8 +560,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { } else { if sshConf.GetSshEnabled() { - if runtime.GOOS == "windows" { - log.Warnf("running SSH server on Windows is not supported") + if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" { + log.Warnf("running SSH server on %s is not supported", runtime.GOOS) return nil } // start SSH server if it wasn't running @@ -595,7 +634,14 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { // E.g. when a new peer has been registered and we are allowed to connect to it. func (e *Engine) receiveManagementEvents() { go func() { - err := e.mgmClient.Sync(e.ctx, e.handleSync) + info, err := system.GetInfoWithChecks(e.ctx, e.checks) + if err != nil { + log.Warnf("failed to get system info with checks: %v", err) + info = system.GetInfo(e.ctx) + } + + // err = e.mgmClient.Sync(info, e.handleSync) + err = e.mgmClient.Sync(e.ctx, info, e.handleSync) if err != nil { // happens if management is unavailable for a long time. // We want to cancel the operation of the whole client @@ -662,6 +708,20 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } + protoRoutes := networkMap.GetRoutes() + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + + _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) + if err != nil { + log.Errorf("failed to update clientRoutes, err: %v", err) + } + + e.clientRoutesMu.Lock() + e.clientRoutes = clientRoutes + e.clientRoutesMu.Unlock() + log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) e.updateOfflinePeers(networkMap.GetOfflinePeers()) @@ -703,19 +763,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } } - protoRoutes := networkMap.GetRoutes() - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} - } - - _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) - if err != nil { - log.Errorf("failed to update clientRoutes, err: %v", err) - } - - e.clientRoutesMu.Lock() - e.clientRoutes = clientRoutes - e.clientRoutesMu.Unlock() protoDNSConfig := networkMap.GetDNSConfig() if protoDNSConfig == nil { @@ -743,15 +790,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { routes := make([]*route.Route, 0) for _, protoRoute := range protoRoutes { - _, prefix, _ := route.ParseNetwork(protoRoute.Network) + var prefix netip.Prefix + if len(protoRoute.Domains) == 0 { + var err error + if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil { + log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err) + continue + } + } convertedRoute := &route.Route{ ID: route.ID(protoRoute.ID), Network: prefix, + Domains: domain.FromPunycodeList(protoRoute.Domains), NetID: route.NetID(protoRoute.NetID), NetworkType: route.NetworkType(protoRoute.NetworkType), Peer: protoRoute.Peer, Metric: int(protoRoute.Metric), Masquerade: protoRoute.Masquerade, + KeepRoute: protoRoute.KeepRoute, } routes = append(routes, convertedRoute) } @@ -1105,7 +1161,8 @@ func (e *Engine) close() { } func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { - netMap, err := e.mgmClient.GetNetworkMap() + info := system.GetInfo(e.ctx) + netMap, err := e.mgmClient.GetNetworkMap(info) if err != nil { return nil, nil, err } @@ -1134,7 +1191,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { default: } - return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs) + return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes) } func (e *Engine) wgInterfaceCreate() (err error) { @@ -1309,6 +1366,15 @@ func (e *Engine) probeTURNs() []relay.ProbeResult { return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs) } +func (e *Engine) restartEngine() { + if err := e.Stop(); err != nil { + log.Errorf("Failed to stop engine: %v", err) + } + if err := e.Start(); err != nil { + log.Errorf("Failed to start engine: %v", err) + } +} + func (e *Engine) startNetworkMonitor() { if !e.config.NetworkMonitor { log.Infof("Network monitor is disabled, not starting") @@ -1317,17 +1383,54 @@ func (e *Engine) startNetworkMonitor() { e.networkMonitor = networkmonitor.New() go func() { + var mu sync.Mutex + var debounceTimer *time.Timer + + // Start the network monitor with a callback, Start will block until the monitor is stopped, + // a network change is detected, or an error occurs on start up err := e.networkMonitor.Start(e.ctx, func() { - log.Infof("Network monitor detected network change, restarting engine") - if err := e.Stop(); err != nil { - log.Errorf("Failed to stop engine: %v", err) - } - if err := e.Start(); err != nil { - log.Errorf("Failed to start engine: %v", err) + // This function is called when a network change is detected + mu.Lock() + defer mu.Unlock() + + if debounceTimer != nil { + debounceTimer.Stop() } + + // Set a new timer to debounce rapid network changes + debounceTimer = time.AfterFunc(1*time.Second, func() { + // This function is called after the debounce period + mu.Lock() + defer mu.Unlock() + + log.Infof("Network monitor detected network change, restarting engine") + e.restartEngine() + }) }) if err != nil && !errors.Is(err, networkmonitor.ErrStopped) { log.Errorf("Network monitor: %v", err) } }() } + +func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { + var vpnRoutes []netip.Prefix + for _, routes := range e.GetClientRoutes() { + if len(routes) > 0 && routes[0] != nil { + vpnRoutes = append(vpnRoutes, routes[0].Network) + } + } + + if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn { + return true, prefix, nil + } + + return false, netip.Prefix{}, nil +} + +// isChecksEqual checks if two slices of checks are equal. +func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { + return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { + return slices.Equal(checks.Files, oChecks.Files) + }) +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index af0662541..79b3cd498 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -17,6 +17,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" @@ -58,9 +59,9 @@ var ( ) func TestEngine_SSH(t *testing.T) { - - if runtime.GOOS == "windows" { - t.Skip("skipping TestEngine_SSH on Windows") + // todo resolve test execution on freebsd + if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" { + t.Skip("skipping TestEngine_SSH") } key, err := wgtypes.GeneratePrivateKey() @@ -80,7 +81,7 @@ func TestEngine_SSH(t *testing.T) { WgPort: 33100, ServerSSHAllowed: true, }, - MobileDependency{}, peer.NewRecorder("https://mgm")) + MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, @@ -176,7 +177,7 @@ func TestEngine_SSH(t *testing.T) { t.Fatal(err) } - //time.Sleep(250 * time.Millisecond) + // time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=") @@ -215,16 +216,16 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm")) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil) + engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } @@ -397,7 +398,7 @@ func TestEngine_Sync(t *testing.T) { // feed updates to Engine via mocked Management client updates := make(chan *mgmtProto.SyncResponse) defer close(updates) - syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error { + syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error { for msg := range updates { err := msgHandler(msg) if err != nil { @@ -412,7 +413,7 @@ func TestEngine_Sync(t *testing.T) { WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm")) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -572,13 +573,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm")) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) assert.NoError(t, err, "shouldn't return error") input := struct { inputSerial uint64 @@ -743,14 +744,14 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm")) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx newNet, err := stdnet.NewNet() if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil) assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ @@ -816,13 +817,13 @@ func TestEngine_MultiplePeers(t *testing.T) { ctx, cancel := context.WithCancel(CtxInitState(context.Background())) defer cancel() - sigServer, signalAddr, err := startSignal() + sigServer, signalAddr, err := startSignal(t) if err != nil { t.Fatal(err) return } defer sigServer.Stop() - mgmtServer, mgmtAddr, err := startManagement(dir) + mgmtServer, mgmtAddr, err := startManagement(t, dir) if err != nil { t.Fatal(err) return @@ -1015,12 +1016,14 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String()) - e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil + e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil e.ctx = ctx return e, err } -func startSignal() (*grpc.Server, string, error) { +func startSignal(t *testing.T) (*grpc.Server, string, error) { + t.Helper() + s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) lis, err := net.Listen("tcp", "localhost:0") @@ -1028,7 +1031,9 @@ func startSignal() (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - proto.RegisterSignalExchangeServer(s, signalServer.NewServer()) + srv, err := signalServer.NewServer(otel.Meter("")) + require.NoError(t, err) + proto.RegisterSignalExchangeServer(s, srv) go func() { if err = s.Serve(lis); err != nil { @@ -1039,7 +1044,9 @@ func startSignal() (*grpc.Server, string, error) { return s, lis.Addr().String(), nil } -func startManagement(dataDir string) (*grpc.Server, string, error) { +func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) { + t.Helper() + config := &server.Config{ Stuns: []*server.Host{}, TURNConfig: &server.TURNConfig{}, @@ -1056,23 +1063,25 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, _, err := server.NewTestStoreFromJson(config.Datadir) + + store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { return nil, "", err } + t.Cleanup(cleanUp) peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(eventStore) - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) + ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "") - mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { return nil, "", err } diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go index de4209f5d..8d6ccd51b 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/monitor_bsd.go @@ -5,8 +5,6 @@ package networkmonitor import ( "context" "fmt" - "net" - "net/netip" "syscall" "unsafe" @@ -14,10 +12,10 @@ import ( "golang.org/x/net/route" "golang.org/x/sys/unix" - "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error { +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) if err != nil { return fmt.Errorf("failed to open routing socket: %v", err) @@ -47,24 +45,6 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) switch msg.Type { - - // handle interface state changes - case unix.RTM_IFINFO: - ifinfo, err := parseInterfaceMessage(buf[:n]) - if err != nil { - log.Errorf("Network monitor: error parsing interface message: %v", err) - continue - } - if msg.Flags&unix.IFF_UP != 0 { - continue - } - if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) { - continue - } - - log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name) - go callback() - // handle route changes case unix.RTM_ADD, syscall.RTM_DELETE: route, err := parseRouteMessage(buf[:n]) @@ -86,7 +66,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) go callback() case unix.RTM_DELETE: - if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 { + if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) go callback() } @@ -96,25 +76,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac } } -func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) { - msgs, err := route.ParseRIB(route.RIBTypeInterface, buf) - if err != nil { - return nil, fmt.Errorf("parse RIB: %v", err) - } - - if len(msgs) != 1 { - return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) - } - - msg, ok := msgs[0].(*route.InterfaceMessage) - if !ok { - return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) - } - - return msg, nil -} - -func parseRouteMessage(buf []byte) (*routemanager.Route, error) { +func parseRouteMessage(buf []byte) (*systemops.Route, error) { msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) if err != nil { return nil, fmt.Errorf("parse RIB: %v", err) @@ -129,5 +91,5 @@ func parseRouteMessage(buf []byte) (*routemanager.Route, error) { return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) } - return routemanager.MsgToRoute(msg) + return systemops.MsgToRoute(msg) } diff --git a/client/internal/networkmonitor/monitor_generic.go b/client/internal/networkmonitor/monitor_generic.go index 97cfbc2ca..f5cc19473 100644 --- a/client/internal/networkmonitor/monitor_generic.go +++ b/client/internal/networkmonitor/monitor_generic.go @@ -6,14 +6,13 @@ import ( "context" "errors" "fmt" - "net" "net/netip" "runtime/debug" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) // Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns. @@ -29,23 +28,22 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error nw.wg.Add(1) defer nw.wg.Done() - var nexthop4, nexthop6 netip.Addr - var intf4, intf6 *net.Interface + var nexthop4, nexthop6 systemops.Nexthop operation := func() error { var errv4, errv6 error - nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified()) - nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified()) + nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified()) + nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified()) if errv4 != nil && errv6 != nil { return errors.New("failed to get default next hops") } if errv4 == nil { - log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name) + log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name) } if errv6 == nil { - log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name) + log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name) } // continue if either route was found @@ -65,7 +63,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error } }() - if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil { + if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil { return fmt.Errorf("check change: %w", err) } diff --git a/client/internal/networkmonitor/monitor_linux.go b/client/internal/networkmonitor/monitor_linux.go index 3f93c6ac6..035be1f09 100644 --- a/client/internal/networkmonitor/monitor_linux.go +++ b/client/internal/networkmonitor/monitor_linux.go @@ -6,27 +6,22 @@ import ( "context" "errors" "fmt" - "net" - "net/netip" "syscall" log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error { - if intfv4 == nil && intfv6 == nil { +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { + if nexthopv4.Intf == nil && nexthopv6.Intf == nil { return errors.New("no interfaces available") } - linkChan := make(chan netlink.LinkUpdate) done := make(chan struct{}) defer close(done) - if err := netlink.LinkSubscribe(linkChan, done); err != nil { - return fmt.Errorf("subscribe to link updates: %v", err) - } - routeChan := make(chan netlink.RouteUpdate) if err := netlink.RouteSubscribe(routeChan, done); err != nil { return fmt.Errorf("subscribe to route updates: %v", err) @@ -38,25 +33,6 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac case <-ctx.Done(): return ErrStopped - // handle interface state changes - case update := <-linkChan: - if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) { - continue - } - - switch update.Header.Type { - case syscall.RTM_DELLINK: - log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name) - go callback() - return nil - case syscall.RTM_NEWLINK: - if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown { - log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name) - go callback() - return nil - } - } - // handle route changes case route := <-routeChan: // default route and main table @@ -70,7 +46,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac go callback() return nil case syscall.RTM_DELROUTE: - if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) { + if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) { log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) go callback() return nil diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go index b8d9c6de7..e24bdd066 100644 --- a/client/internal/networkmonitor/monitor_windows.go +++ b/client/internal/networkmonitor/monitor_windows.go @@ -5,11 +5,12 @@ import ( "fmt" "net" "net/netip" + "strings" "time" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) const ( @@ -25,20 +26,16 @@ const ( const interval = 10 * time.Second -func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error { - var neighborv4, neighborv6 *routemanager.Neighbor +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { + var neighborv4, neighborv6 *systemops.Neighbor { initialNeighbors, err := getNeighbors() if err != nil { return fmt.Errorf("get neighbors: %w", err) } - if n, ok := initialNeighbors[nexthopv4]; ok { - neighborv4 = &n - } - if n, ok := initialNeighbors[nexthopv6]; ok { - neighborv6 = &n - } + neighborv4 = assignNeighbor(nexthopv4, initialNeighbors) + neighborv6 = assignNeighbor(nexthopv6, initialNeighbors) } log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6) @@ -50,7 +47,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac case <-ctx.Done(): return ErrStopped case <-ticker.C: - if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) { + if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) { go callback() return nil } @@ -58,13 +55,21 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac } } +func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor { + if n, ok := initialNeighbors[nexthop.IP]; ok && + n.State != unreachable && + n.State != incomplete && + n.State != tbd { + return &n + } + return nil +} + func changed( - nexthopv4 netip.Addr, - intfv4 *net.Interface, - neighborv4 *routemanager.Neighbor, - nexthopv6 netip.Addr, - intfv6 *net.Interface, - neighborv6 *routemanager.Neighbor, + nexthopv4 systemops.Nexthop, + neighborv4 *systemops.Neighbor, + nexthopv6 systemops.Nexthop, + neighborv6 *systemops.Neighbor, ) bool { neighbors, err := getNeighbors() if err != nil { @@ -81,7 +86,7 @@ func changed( return false } - if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) { + if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) { return true } @@ -89,44 +94,74 @@ func changed( } // routeChanged checks if the default routes still point to our nexthop/interface -func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool { - if !nexthop.IsValid() { +func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool { + if !nexthop.IP.IsValid() { return false } - var unspec netip.Prefix - if nexthop.Is6() { - unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - } else { - unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0) - } + unspec := getUnspecifiedPrefix(nexthop.IP) + defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec) - if r, ok := routes[unspec]; ok { - if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 { - intf := "" - if r.Interface != nil { - intf = r.Interface.Name - } - log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf) - return true - } - } else { - log.Infof("network monitor: default route is gone") + log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n")) + + if !foundMatchingRoute { + logRouteChange(nexthop.IP, intf) return true } return false - } -func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool { +func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix { + if ip.Is6() { + return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + } + return netip.PrefixFrom(netip.IPv4Unspecified(), 0) +} + +func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) { + var defaultRoutes []string + foundMatchingRoute := false + + for _, r := range routes { + if r.Destination == unspec { + routeInfo := formatRouteInfo(r) + defaultRoutes = append(defaultRoutes, routeInfo) + + if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 { + foundMatchingRoute = true + log.Debugf("network monitor: found matching default route: %s", routeInfo) + } + } + } + + return defaultRoutes, foundMatchingRoute +} + +func formatRouteInfo(r systemops.Route) string { + newIntf := "" + if r.Interface != nil { + newIntf = r.Interface.Name + } + return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf) +} + +func logRouteChange(ip netip.Addr, intf *net.Interface) { + oldIntf := "" + if intf != nil { + oldIntf = intf.Name + } + log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf) +} + +func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool { if neighbor == nil { return false } // TODO: consider non-local nexthops, e.g. on point-to-point interfaces - if n, ok := neighbors[nexthop]; ok { - if n.State != reachable && n.State != permanent { + if n, ok := neighbors[nexthop.IP]; ok { + if n.State == unreachable || n.State == incomplete { log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State)) return true } else if n.InterfaceIndex != neighbor.InterfaceIndex { @@ -150,13 +185,13 @@ func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighb return false } -func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) { - entries, err := routemanager.GetNeighbors() +func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) { + entries, err := systemops.GetNeighbors() if err != nil { return nil, fmt.Errorf("get neighbors: %w", err) } - neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries)) + neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries)) for _, entry := range entries { neighbours[entry.IPAddress] = entry } @@ -164,18 +199,13 @@ func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) { return neighbours, nil } -func getRoutes() (map[netip.Prefix]routemanager.Route, error) { - entries, err := routemanager.GetRoutes() +func getRoutes() ([]systemops.Route, error) { + entries, err := systemops.GetRoutes() if err != nil { return nil, fmt.Errorf("get routes: %w", err) } - routes := make(map[netip.Prefix]routemanager.Route, len(entries)) - for _, entry := range entries { - routes[entry.Destination] = entry - } - - return routes, nil + return entries, nil } func stateFromInt(state uint8) string { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ac5c46f83..3901709ef 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -62,9 +62,6 @@ type ConnConfig struct { ICEConfig ICEConfig } -type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error -type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error - type WorkerCallbacks struct { OnRelayReadyCallback func(info RelayConnInfo) OnRelayStatusChanged func(ConnStatus) @@ -99,8 +96,8 @@ type Conn struct { workerRelay *WorkerRelay connID nbnet.ConnectionID - beforeAddPeerHooks []BeforeAddPeerHookFunc - afterRemovePeerHooks []AfterRemovePeerHookFunc + beforeAddPeerHooks []nbnet.AddHookFunc + afterRemovePeerHooks []nbnet.RemoveHookFunc endpointRelay *net.UDPAddr @@ -266,11 +263,10 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa conn.workerICE.OnRemoteCandidate(candidate, haRoutes) } -func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { +func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) { conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) } - -func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { +func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) { conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 78ccf9724..7caef01bf 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -46,7 +46,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -61,7 +61,7 @@ func TestConn_GetKey(t *testing.T) { } func TestConn_OnRemoteOffer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { } func TestConn_OnRemoteAnswer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -134,7 +134,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg.Wait() } func TestConn_Status(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -172,7 +172,7 @@ func TestConn_Status(t *testing.T) { func TestConn_Switch(t *testing.T) { ctx := context.Background() - wgProxyFactory := wgproxy.NewFactory(ctx, connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(ctx, false, connConf.LocalWgPort) connConfAlice := ConnConfig{ Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index ddea7d04e..a7cfb95c4 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -2,14 +2,17 @@ package peer import ( "errors" + "net/netip" "sync" "time" + "golang.org/x/exp/maps" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/management/domain" ) // State contains the latest state of a peer @@ -37,25 +40,25 @@ type State struct { // AddRoute add a single route to routes map func (s *State) AddRoute(network string) { s.Mux.Lock() + defer s.Mux.Unlock() if s.routes == nil { s.routes = make(map[string]struct{}) } s.routes[network] = struct{}{} - s.Mux.Unlock() } // SetRoutes set state routes func (s *State) SetRoutes(routes map[string]struct{}) { s.Mux.Lock() + defer s.Mux.Unlock() s.routes = routes - s.Mux.Unlock() } // DeleteRoute removes a route from the network amp func (s *State) DeleteRoute(network string) { s.Mux.Lock() + defer s.Mux.Unlock() delete(s.routes, network) - s.Mux.Unlock() } // GetRoutes return routes map @@ -117,22 +120,23 @@ type FullStatus struct { // Status holds a state of peers, signal, management connections and relays type Status struct { - mux sync.Mutex - peers map[string]State - changeNotify map[string]chan struct{} - signalState bool - signalError error - managementState bool - managementError error - relayStates []relay.ProbeResult - localPeer LocalPeerState - offlinePeers []State - mgmAddress string - signalAddress string - notifier *notifier - rosenpassEnabled bool - rosenpassPermissive bool - nsGroupStates []NSGroupState + mux sync.Mutex + peers map[string]State + changeNotify map[string]chan struct{} + signalState bool + signalError error + managementState bool + managementError error + relayStates []relay.ProbeResult + localPeer LocalPeerState + offlinePeers []State + mgmAddress string + signalAddress string + notifier *notifier + rosenpassEnabled bool + rosenpassPermissive bool + nsGroupStates []NSGroupState + resolvedDomainsStates map[domain.Domain][]netip.Prefix // To reduce the number of notification invocation this bool will be true when need to call the notification // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events @@ -143,11 +147,12 @@ type Status struct { // NewRecorder returns a new Status instance func NewRecorder(mgmAddress string) *Status { return &Status{ - peers: make(map[string]State), - changeNotify: make(map[string]chan struct{}), - offlinePeers: make([]State, 0), - notifier: newNotifier(), - mgmAddress: mgmAddress, + peers: make(map[string]State), + changeNotify: make(map[string]chan struct{}), + offlinePeers: make([]State, 0), + notifier: newNotifier(), + mgmAddress: mgmAddress, + resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix), } } @@ -188,7 +193,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) { state, ok := d.peers[peerPubKey] if !ok { - return State{}, errors.New("peer not found") + return State{}, iface.ErrPeerNotFound } return state, nil } @@ -429,6 +434,18 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) { d.nsGroupStates = dnsStates } +func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) { + d.mux.Lock() + defer d.mux.Unlock() + d.resolvedDomainsStates[domain] = prefixes +} + +func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) { + d.mux.Lock() + defer d.mux.Unlock() + delete(d.resolvedDomainsStates, domain) +} + func (d *Status) GetRosenpassState() RosenpassState { return RosenpassState{ d.rosenpassEnabled, @@ -493,6 +510,12 @@ func (d *Status) GetDNSStates() []NSGroupState { return d.nsGroupStates } +func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix { + d.mux.Lock() + defer d.mux.Unlock() + return maps.Clone(d.resolvedDomainsStates) +} + // GetFullStatus gets full status func (d *Status) GetFullStatus() FullStatus { d.mux.Lock() diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index e82f4b1da..3c230df21 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -3,19 +3,20 @@ package routemanager import ( "context" "fmt" - "net" - "net/netip" "time" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) -const minRangeBits = 7 - type routerPeerStatus struct { connected bool relayed bool @@ -28,33 +29,42 @@ type routesUpdate struct { routes []*route.Route } +// RouteHandler defines the interface for handling routes +type RouteHandler interface { + String() string + AddRoute(ctx context.Context) error + RemoveRoute() error + AddAllowedIPs(peerKey string) error + RemoveAllowedIPs() error +} + type clientNetwork struct { ctx context.Context - stop context.CancelFunc + cancel context.CancelFunc statusRecorder *peer.Status wgInterface *iface.WGIface routes map[route.ID]*route.Route routeUpdate chan routesUpdate peerStateUpdate chan struct{} routePeersNotifiers map[string]chan struct{} - chosenRoute *route.Route - network netip.Prefix + currentChosen *route.Route + handler RouteHandler updateSerial uint64 } -func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { +func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { ctx, cancel := context.WithCancel(ctx) client := &clientNetwork{ ctx: ctx, - stop: cancel, + cancel: cancel, statusRecorder: statusRecorder, wgInterface: wgInterface, routes: make(map[route.ID]*route.Route), routePeersNotifiers: make(map[string]chan struct{}), routeUpdate: make(chan routesUpdate), peerStateUpdate: make(chan struct{}), - network: network, + handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder), } return client } @@ -86,8 +96,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { // * Metric: Routes with lower metrics (better) are prioritized. // * Non-relayed: Routes without relays are preferred. // * Direct connections: Routes with direct peer connections are favored. -// * Stability: In case of equal scores, the currently active route (if any) is maintained. // * Latency: Routes with lower latency are prioritized. +// * Stability: In case of equal scores, the currently active route (if any) is maintained. // // It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { @@ -96,8 +106,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] currScore := float64(0) currID := route.ID("") - if c.chosenRoute != nil { - currID = c.chosenRoute.ID + if c.currentChosen != nil { + currID = c.currentChosen.ID } for _, r := range c.routes { @@ -151,18 +161,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] peers = append(peers, r.Peer) } - log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers) + log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers) case chosen != currID: // we compare the current score + 10ms to the chosen score to avoid flapping between routes if currScore != 0 && currScore+0.01 > chosenScore { - log.Debugf("keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore) + log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore) return currID } var p string if rt := c.routes[chosen]; rt != nil { p = rt.Peer } - log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, p, chosenScore, c.network) + log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler) } return chosen @@ -196,98 +206,103 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { } } -func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { - state, err := c.statusRecorder.GetPeer(peerKey) - if err != nil { - return fmt.Errorf("get peer state: %v", err) - } +func (c *clientNetwork) removeRouteFromWireguardPeer() error { + c.removeStateRoute() - state.DeleteRoute(c.network.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } - - if state.ConnStatus != peer.StatusConnected { - return nil - } - - err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) - if err != nil { - return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v", - c.network, c.chosenRoute.Peer, err) + if err := c.handler.RemoveAllowedIPs(); err != nil { + return fmt.Errorf("remove allowed IPs: %w", err) } return nil } func (c *clientNetwork) removeRouteFromPeerAndSystem() error { - if c.chosenRoute != nil { - if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil { - return fmt.Errorf("remove route %s from system, err: %v", c.network, err) - } - - if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { - return fmt.Errorf("remove route: %v", err) - } + if c.currentChosen == nil { + return nil } - return nil + + var merr *multierror.Error + + if err := c.removeRouteFromWireguardPeer(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)) + } + if err := c.handler.RemoveRoute(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) } func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { routerPeerStatuses := c.getRouterPeerStatuses() - chosen := c.getBestRouteFromStatuses(routerPeerStatuses) + newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses) // If no route is chosen, remove the route from the peer and system - if chosen == "" { + if newChosenID == "" { if err := c.removeRouteFromPeerAndSystem(); err != nil { - return fmt.Errorf("remove route from peer and system: %v", err) + return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err) } - c.chosenRoute = nil + c.currentChosen = nil return nil } // If the chosen route is the same as the current route, do nothing - if c.chosenRoute != nil && c.chosenRoute.ID == chosen { - if c.chosenRoute.IsEqual(c.routes[chosen]) { - return nil - } + if c.currentChosen != nil && c.currentChosen.ID == newChosenID && + c.currentChosen.IsEqual(c.routes[newChosenID]) { + return nil } - if c.chosenRoute != nil { - // If a previous route exists, remove it from the peer - if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { - return fmt.Errorf("remove route from peer: %v", err) + if c.currentChosen == nil { + // If they were not previously assigned to another peer, add routes to the system first + if err := c.handler.AddRoute(c.ctx); err != nil { + return fmt.Errorf("add route: %w", err) } } else { - // otherwise add the route to the system - if err := addVPNRoute(c.network, c.getAsInterface()); err != nil { - return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", - c.network.String(), c.wgInterface.Address().IP.String(), err) + // Otherwise, remove the allowed IPs from the previous peer first + if err := c.removeRouteFromWireguardPeer(); err != nil { + return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } } - c.chosenRoute = c.routes[chosen] + c.currentChosen = c.routes[newChosenID] - state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer) - if err != nil { - log.Errorf("Failed to get peer state: %v", err) - } else { - state.AddRoute(c.network.String()) - if err := c.statusRecorder.UpdatePeerState(state); err != nil { - log.Warnf("Failed to update peer state: %v", err) - } + if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil { + return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) } - if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil { - log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", - c.network, c.chosenRoute.Peer, err) - } + c.addStateRoute() return nil } +func (c *clientNetwork) addStateRoute() { + state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) + if err != nil { + log.Errorf("Failed to get peer state: %v", err) + return + } + + state.AddRoute(c.handler.String()) + if err := c.statusRecorder.UpdatePeerState(state); err != nil { + log.Warnf("Failed to update peer state: %v", err) + } +} + +func (c *clientNetwork) removeStateRoute() { + state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer) + if err != nil { + log.Errorf("Failed to get peer state: %v", err) + return + } + + state.DeleteRoute(c.handler.String()) + if err := c.statusRecorder.UpdatePeerState(state); err != nil { + log.Warnf("Failed to update peer state: %v", err) + } +} + func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { go func() { c.routeUpdate <- update @@ -318,24 +333,23 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { for { select { case <-c.ctx.Done(): - log.Debugf("stopping watcher for network %s", c.network) - err := c.removeRouteFromPeerAndSystem() - if err != nil { - log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err) + log.Debugf("Stopping watcher for network [%v]", c.handler) + if err := c.removeRouteFromPeerAndSystem(); err != nil { + log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err) } return case <-c.peerStateUpdate: err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Errorf("Couldn't recalculate route and update peer and system: %v", err) + log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) } case update := <-c.routeUpdate: if update.updateSerial < c.updateSerial { - log.Warnf("Received a routes update with smaller serial number, ignoring it") + log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial) continue } - log.Debugf("Received a new client network route update for %s", c.network) + log.Debugf("Received a new client network route update for [%v]", c.handler) c.handleUpdate(update) @@ -343,7 +357,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err) + log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) } c.startPeersStatusChangeWatcher() @@ -351,14 +365,9 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } -func (c *clientNetwork) getAsInterface() *net.Interface { - intf, err := net.InterfaceByName(c.wgInterface.Name()) - if err != nil { - log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err) - intf = &net.Interface{ - Name: c.wgInterface.Name(), - } +func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler { + if rt.IsDynamic() { + return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder) } - - return intf + return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) } diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index 9419ea777..0ae10e568 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/route" ) @@ -340,9 +341,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) { // create new clientNetwork client := &clientNetwork{ - network: netip.MustParsePrefix("192.168.0.0/24"), - routes: tc.existingRoutes, - chosenRoute: currentRoute, + handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil), + routes: tc.existingRoutes, + currentChosen: currentRoute, } chosenRoute := client.getBestRouteFromStatuses(tc.statuses) diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go new file mode 100644 index 000000000..8429b4534 --- /dev/null +++ b/client/internal/routemanager/dynamic/route.go @@ -0,0 +1,378 @@ +package dynamic + +import ( + "context" + "fmt" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" +) + +const ( + DefaultInterval = time.Minute + + minInterval = 2 * time.Second + failureInterval = 5 * time.Second + + addAllowedIP = "add allowed IP %s: %w" +) + +type domainMap map[domain.Domain][]netip.Prefix + +type resolveResult struct { + domain domain.Domain + prefix netip.Prefix + err error +} + +type Route struct { + route *route.Route + routeRefCounter *refcounter.RouteRefCounter + allowedIPsRefcounter *refcounter.AllowedIPsRefCounter + interval time.Duration + dynamicDomains domainMap + mu sync.Mutex + currentPeerKey string + cancel context.CancelFunc + statusRecorder *peer.Status +} + +func NewRoute( + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + interval time.Duration, + statusRecorder *peer.Status, +) *Route { + return &Route{ + route: rt, + routeRefCounter: routeRefCounter, + allowedIPsRefcounter: allowedIPsRefCounter, + interval: interval, + dynamicDomains: domainMap{}, + statusRecorder: statusRecorder, + } +} + +func (r *Route) String() string { + s, err := r.route.Domains.String() + if err != nil { + return r.route.Domains.PunycodeString() + } + return s +} + +func (r *Route) AddRoute(ctx context.Context) error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.cancel != nil { + r.cancel() + } + + ctx, r.cancel = context.WithCancel(ctx) + + go r.startResolver(ctx) + + return nil +} + +// RemoveRoute will stop the dynamic resolver and remove all dynamic routes. +// It doesn't touch allowed IPs, these should be removed separately and before calling this method. +func (r *Route) RemoveRoute() error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.cancel != nil { + r.cancel() + } + + var merr *multierror.Error + for domain, prefixes := range r.dynamicDomains { + for _, prefix := range prefixes { + if _, err := r.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err)) + } + } + log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) + + r.statusRecorder.DeleteResolvedDomainsStates(domain) + } + + r.dynamicDomains = domainMap{} + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *Route) AddAllowedIPs(peerKey string) error { + r.mu.Lock() + defer r.mu.Unlock() + + var merr *multierror.Error + for domain, domainPrefixes := range r.dynamicDomains { + for _, prefix := range domainPrefixes { + if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err)) + } + } + } + r.currentPeerKey = peerKey + return nberrors.FormatErrorOrNil(merr) +} + +func (r *Route) RemoveAllowedIPs() error { + r.mu.Lock() + defer r.mu.Unlock() + + var merr *multierror.Error + for _, domainPrefixes := range r.dynamicDomains { + for _, prefix := range domainPrefixes { + if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err)) + } + } + } + + r.currentPeerKey = "" + return nberrors.FormatErrorOrNil(merr) +} + +func (r *Route) startResolver(ctx context.Context) { + log.Debugf("Starting dynamic route resolver for domains [%v]", r) + + interval := r.interval + if interval < minInterval { + interval = minInterval + log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval) + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + if err := r.update(ctx); err != nil { + log.Errorf("Failed to resolve domains for route [%v]: %v", r, err) + if interval > failureInterval { + ticker.Reset(failureInterval) + } + } + + for { + select { + case <-ctx.Done(): + log.Debugf("Stopping dynamic route resolver for domains [%v]", r) + return + case <-ticker.C: + if err := r.update(ctx); err != nil { + log.Errorf("Failed to resolve domains for route [%v]: %v", r, err) + // Use a lower ticker interval if the update fails + if interval > failureInterval { + ticker.Reset(failureInterval) + } + } else if interval > failureInterval { + // Reset to the original interval if the update succeeds + ticker.Reset(interval) + } + } + } +} + +func (r *Route) update(ctx context.Context) error { + if resolved, err := r.resolveDomains(); err != nil { + return fmt.Errorf("resolve domains: %w", err) + } else if err := r.updateDynamicRoutes(ctx, resolved); err != nil { + return fmt.Errorf("update dynamic routes: %w", err) + } + + return nil +} + +func (r *Route) resolveDomains() (domainMap, error) { + results := make(chan resolveResult) + go r.resolve(results) + + resolved := domainMap{} + var merr *multierror.Error + + for result := range results { + if result.err != nil { + merr = multierror.Append(merr, result.err) + } else { + resolved[result.domain] = append(resolved[result.domain], result.prefix) + } + } + + return resolved, nberrors.FormatErrorOrNil(merr) +} + +func (r *Route) resolve(results chan resolveResult) { + var wg sync.WaitGroup + + for _, d := range r.route.Domains { + wg.Add(1) + go func(domain domain.Domain) { + defer wg.Done() + ips, err := net.LookupIP(string(domain)) + if err != nil { + results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)} + return + } + for _, ip := range ips { + prefix, err := util.GetPrefixFromIP(ip) + if err != nil { + results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)} + return + } + results <- resolveResult{domain: domain, prefix: prefix} + } + }(d) + } + + wg.Wait() + close(results) +} + +func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error { + r.mu.Lock() + defer r.mu.Unlock() + + if ctx.Err() != nil { + return ctx.Err() + } + + var merr *multierror.Error + + for domain, newPrefixes := range newDomains { + oldPrefixes := r.dynamicDomains[domain] + toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) + + addedPrefixes, err := r.addRoutes(domain, toAdd) + if err != nil { + merr = multierror.Append(merr, err) + } else if len(addedPrefixes) > 0 { + log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", ")) + } + + removedPrefixes, err := r.removeRoutes(toRemove) + if err != nil { + merr = multierror.Append(merr, err) + } else if len(removedPrefixes) > 0 { + log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", ")) + } + + updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes) + r.dynamicDomains[domain] = updatedPrefixes + + r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) { + var addedPrefixes []netip.Prefix + var merr *multierror.Error + + for _, prefix := range prefixes { + if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err)) + continue + } + if r.currentPeerKey != "" { + if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err)) + } + } + addedPrefixes = append(addedPrefixes, prefix) + } + + return addedPrefixes, merr.ErrorOrNil() +} + +func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) { + if r.route.KeepRoute { + return nil, nil + } + + var removedPrefixes []netip.Prefix + var merr *multierror.Error + + for _, prefix := range prefixes { + if _, err := r.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err)) + } + if r.currentPeerKey != "" { + if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err)) + } + } + removedPrefixes = append(removedPrefixes, prefix) + } + + return removedPrefixes, merr.ErrorOrNil() +} + +func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error { + if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { + return fmt.Errorf(addAllowedIP, prefix, err) + } else if ref.Count > 1 && ref.Out != peerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + prefix.Addr(), + domain.SafeString(), + ref.Out, + ) + + } + return nil +} + +func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) { + prefixSet := make(map[netip.Prefix]bool) + for _, prefix := range oldPrefixes { + prefixSet[prefix] = false + } + for _, prefix := range newPrefixes { + if _, exists := prefixSet[prefix]; exists { + prefixSet[prefix] = true + } else { + toAdd = append(toAdd, prefix) + } + } + for prefix, inUse := range prefixSet { + if !inUse { + toRemove = append(toRemove, prefix) + } + } + return +} + +func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix { + prefixSet := make(map[netip.Prefix]struct{}) + for _, prefix := range oldPrefixes { + prefixSet[prefix] = struct{}{} + } + for _, prefix := range removedPrefixes { + delete(prefixSet, prefix) + } + for _, prefix := range addedPrefixes { + prefixSet[prefix] = struct{}{} + } + + var combinedPrefixes []netip.Prefix + for prefix := range prefixSet { + combinedPrefixes = append(combinedPrefixes, prefix) + } + + return combinedPrefixes +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 47549f74d..0673ea6c3 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -2,18 +2,23 @@ package routemanager import ( "context" + "errors" "fmt" "net" "net/netip" "net/url" "runtime" "sync" + "time" log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" @@ -21,14 +26,9 @@ import ( "github.com/netbirdio/netbird/version" ) -var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) - -// nolint:unused -var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - // Manager is a route manager interface type Manager interface { - Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) + Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector @@ -40,31 +40,71 @@ type Manager interface { // DefaultManager is the default instance of a route manager type DefaultManager struct { - ctx context.Context - stop context.CancelFunc - mux sync.Mutex - clientNetworks map[route.HAUniqueID]*clientNetwork - routeSelector *routeselector.RouteSelector - serverRouter serverRouter - statusRecorder *peer.Status - wgInterface *iface.WGIface - pubKey string - notifier *notifier + ctx context.Context + stop context.CancelFunc + mux sync.Mutex + clientNetworks map[route.HAUniqueID]*clientNetwork + routeSelector *routeselector.RouteSelector + serverRouter serverRouter + sysOps *systemops.SysOps + statusRecorder *peer.Status + wgInterface *iface.WGIface + pubKey string + notifier *notifier + routeRefCounter *refcounter.RouteRefCounter + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter + dnsRouteInterval time.Duration } -func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { +func NewManager( + ctx context.Context, + pubKey string, + dnsRouteInterval time.Duration, + wgInterface *iface.WGIface, + statusRecorder *peer.Status, + initialRoutes []*route.Route, +) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) + sysOps := systemops.NewSysOps(wgInterface) + dm := &DefaultManager{ - ctx: mCTX, - stop: cancel, - clientNetworks: make(map[route.HAUniqueID]*clientNetwork), - routeSelector: routeselector.NewRouteSelector(), - statusRecorder: statusRecorder, - wgInterface: wgInterface, - pubKey: pubKey, - notifier: newNotifier(), + ctx: mCTX, + stop: cancel, + dnsRouteInterval: dnsRouteInterval, + clientNetworks: make(map[route.HAUniqueID]*clientNetwork), + routeSelector: routeselector.NewRouteSelector(), + sysOps: sysOps, + statusRecorder: statusRecorder, + wgInterface: wgInterface, + pubKey: pubKey, + notifier: newNotifier(), } + dm.routeRefCounter = refcounter.New( + func(prefix netip.Prefix, _ any) (any, error) { + return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) + }, + func(prefix netip.Prefix, _ any) error { + return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface()) + }, + ) + + dm.allowedIPsRefCounter = refcounter.New( + func(prefix netip.Prefix, peerKey string) (string, error) { + // save peerKey to use it in the remove function + return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String()) + }, + func(prefix netip.Prefix, peerKey string) error { + if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { + if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) { + return err + } + log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err) + } + return nil + }, + ) + if runtime.GOOS == "android" { cr := dm.clientRoutes(initialRoutes) dm.notifier.setInitialClientRoutes(cr) @@ -73,12 +113,12 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, } // Init sets up the routing -func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { +func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { if nbnet.CustomRoutingDisabled() { return nil, nil, nil } - if err := cleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -86,7 +126,7 @@ func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePee signalAddress := m.statusRecorder.GetSignalState().URL ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) - beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) + beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips) if err != nil { return nil, nil, fmt.Errorf("setup routing: %w", err) } @@ -110,8 +150,19 @@ func (m *DefaultManager) Stop() { m.serverRouter.cleanUp() } + if m.routeRefCounter != nil { + if err := m.routeRefCounter.Flush(); err != nil { + log.Errorf("Error flushing route ref counter: %v", err) + } + } + if m.allowedIPsRefCounter != nil { + if err := m.allowedIPsRefCounter.Flush(); err != nil { + log.Errorf("Error flushing allowed IPs ref counter: %v", err) + } + } + if !nbnet.CustomRoutingDisabled() { - if err := cleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") @@ -185,7 +236,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { continue } - clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) + clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) @@ -197,7 +248,7 @@ func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) { for id, client := range m.clientNetworks { if _, ok := networks[id]; !ok { log.Debugf("Stopping client network watcher, %s", id) - client.stop() + client.cancel() delete(m.clientNetworks, id) } } @@ -210,7 +261,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) + clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() } @@ -228,7 +279,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] ownNetworkIDs := make(map[route.HAUniqueID]bool) for _, newRoute := range newRoutes { - haID := route.GetHAUniqueID(newRoute) + haID := newRoute.GetHAUniqueID() if newRoute.Peer == m.pubKey { ownNetworkIDs[haID] = true // only linux is supported for now @@ -241,9 +292,9 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] } for _, newRoute := range newRoutes { - haID := route.GetHAUniqueID(newRoute) + haID := newRoute.GetHAUniqueID() if !ownNetworkIDs[haID] { - if !isPrefixSupported(newRoute.Network) { + if !isRouteSupported(newRoute) { continue } newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute) @@ -255,23 +306,23 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { _, crMap := m.classifyRoutes(initialRoutes) - rs := make([]*route.Route, 0) + rs := make([]*route.Route, 0, len(crMap)) for _, routes := range crMap { rs = append(rs, routes...) } return rs } -func isPrefixSupported(prefix netip.Prefix) bool { - if !nbnet.CustomRoutingDisabled() { +func isRouteSupported(route *route.Route) bool { + if !nbnet.CustomRoutingDisabled() || route.IsDynamic() { return true } // If prefix is too small, lets assume it is a possible default prefix which is not yet supported // we skip this prefix management - if prefix.Bits() <= minRangeBits { + if route.Network.Bits() <= vars.MinRangeBits { log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", - version.NetbirdVersion(), prefix) + version.NetbirdVersion(), route.Network) return false } return true diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 7eb8dd002..455c7ac0b 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -407,7 +407,7 @@ func TestManagerUpdateRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -416,7 +416,7 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) + routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil) _, _, err = routeManager.Init() @@ -436,7 +436,7 @@ func TestManagerUpdateRoutes(t *testing.T) { require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected - if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { + if testCase.clientNetworkWatchersExpectedAllowed != 0 { expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed } require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index adbef8061..58a66715c 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,10 +6,10 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" - "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/util/net" ) // MockManager is the mock instance of a route manager @@ -20,7 +20,7 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { +func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go new file mode 100644 index 000000000..f1d696ad9 --- /dev/null +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -0,0 +1,155 @@ +package refcounter + +import ( + "errors" + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix. +var ErrIgnore = errors.New("ignore") + +type Ref[O any] struct { + Count int + Out O +} + +type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error) +type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error + +type Counter[I, O any] struct { + // refCountMap keeps track of the reference Ref for prefixes + refCountMap map[netip.Prefix]Ref[O] + refCountMu sync.Mutex + // idMap keeps track of the prefixes associated with an ID for removal + idMap map[string][]netip.Prefix + idMu sync.Mutex + add AddFunc[I, O] + remove RemoveFunc[I, O] +} + +// New creates a new Counter instance +func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] { + return &Counter[I, O]{ + refCountMap: map[netip.Prefix]Ref[O]{}, + idMap: map[string][]netip.Prefix{}, + add: add, + remove: remove, + } +} + +// Increment increments the reference count for the given prefix. +// If this is the first reference to the prefix, the AddFunc is called. +func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + + ref := rm.refCountMap[prefix] + log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + + // Call AddFunc only if it's a new prefix + if ref.Count == 0 { + log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out) + out, err := rm.add(prefix, in) + + if errors.Is(err, ErrIgnore) { + return ref, nil + } + if err != nil { + return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err) + } + ref.Out = out + } + + ref.Count++ + rm.refCountMap[prefix] = ref + + return ref, nil +} + +// IncrementWithID increments the reference count for the given prefix and groups it under the given ID. +// If this is the first reference to the prefix, the AddFunc is called. +func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) { + rm.idMu.Lock() + defer rm.idMu.Unlock() + + ref, err := rm.Increment(prefix, in) + if err != nil { + return ref, fmt.Errorf("with ID: %w", err) + } + rm.idMap[id] = append(rm.idMap[id], prefix) + + return ref, nil +} + +// Decrement decrements the reference count for the given prefix. +// If the reference count reaches 0, the RemoveFunc is called. +func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + + ref, ok := rm.refCountMap[prefix] + if !ok { + log.Tracef("No reference found for prefix %s", prefix) + return ref, nil + } + + log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + if ref.Count == 1 { + log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out) + if err := rm.remove(prefix, ref.Out); err != nil { + return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err) + } + delete(rm.refCountMap, prefix) + } else { + ref.Count-- + rm.refCountMap[prefix] = ref + } + + return ref, nil +} + +// DecrementWithID decrements the reference count for all prefixes associated with the given ID. +// If the reference count reaches 0, the RemoveFunc is called. +func (rm *Counter[I, O]) DecrementWithID(id string) error { + rm.idMu.Lock() + defer rm.idMu.Unlock() + + var merr *multierror.Error + for _, prefix := range rm.idMap[id] { + if _, err := rm.Decrement(prefix); err != nil { + merr = multierror.Append(merr, err) + } + } + delete(rm.idMap, id) + + return nberrors.FormatErrorOrNil(merr) +} + +// Flush removes all references and calls RemoveFunc for each prefix. +func (rm *Counter[I, O]) Flush() error { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + + var merr *multierror.Error + for prefix := range rm.refCountMap { + log.Tracef("Removing for prefix %s", prefix) + ref := rm.refCountMap[prefix] + if err := rm.remove(prefix, ref.Out); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err)) + } + } + rm.refCountMap = map[netip.Prefix]Ref[O]{} + + rm.idMap = map[string][]netip.Prefix{} + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/client/internal/routemanager/refcounter/types.go b/client/internal/routemanager/refcounter/types.go new file mode 100644 index 000000000..6753b64ef --- /dev/null +++ b/client/internal/routemanager/refcounter/types.go @@ -0,0 +1,7 @@ +package refcounter + +// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement +type RouteRefCounter = Counter[any, any] + +// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement +type AllowedIPsRefCounter = Counter[string, string] diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go deleted file mode 100644 index 7715aa819..000000000 --- a/client/internal/routemanager/routemanager.go +++ /dev/null @@ -1,127 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "errors" - "fmt" - "net" - "net/netip" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -type ref struct { - count int - nexthop netip.Addr - intf *net.Interface -} - -type RouteManager struct { - // refCountMap keeps track of the reference ref for prefixes - refCountMap map[netip.Prefix]ref - // prefixMap keeps track of the prefixes associated with a connection ID for removal - prefixMap map[nbnet.ConnectionID][]netip.Prefix - addRoute AddRouteFunc - removeRoute RemoveRouteFunc - mutex sync.Mutex -} - -type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error) -type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error - -func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { - // TODO: read initial routing table into refCountMap - return &RouteManager{ - refCountMap: map[netip.Prefix]ref{}, - prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, - addRoute: addRoute, - removeRoute: removeRoute, - } -} - -func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - ref := rm.refCountMap[prefix] - log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) - - // Add route to the system, only if it's a new prefix - if ref.count == 0 { - log.Debugf("Adding route for prefix %s", prefix) - nexthop, intf, err := rm.addRoute(prefix) - if errors.Is(err, ErrRouteNotFound) { - return nil - } - if errors.Is(err, ErrRouteNotAllowed) { - log.Debugf("Adding route for prefix %s: %s", prefix, err) - } - if err != nil { - return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) - } - ref.nexthop = nexthop - ref.intf = intf - } - - ref.count++ - rm.refCountMap[prefix] = ref - rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) - - return nil -} - -func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - prefixes, ok := rm.prefixMap[connID] - if !ok { - log.Debugf("No prefixes found for connection ID %s", connID) - return nil - } - - var result *multierror.Error - for _, prefix := range prefixes { - ref := rm.refCountMap[prefix] - log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) - if ref.count == 1 { - log.Debugf("Removing route for prefix %s", prefix) - // TODO: don't fail if the route is not found - if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { - result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) - continue - } - delete(rm.refCountMap, prefix) - } else { - ref.count-- - rm.refCountMap[prefix] = ref - } - } - delete(rm.prefixMap, connID) - - return result.ErrorOrNil() -} - -// Flush removes all references and routes from the system -func (rm *RouteManager) Flush() error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - var result *multierror.Error - for prefix := range rm.refCountMap { - log.Debugf("Removing route for prefix %s", prefix) - ref := rm.refCountMap[prefix] - if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { - result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) - } - } - rm.refCountMap = map[netip.Prefix]ref{} - rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} - - return result.ErrorOrNil() -} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 95672e480..8470934c2 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -12,6 +12,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -70,7 +71,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) } if len(m.routes) > 0 { - err := enableIPForwarding() + err := systemops.EnableIPForwarding() if err != nil { return err } @@ -88,7 +89,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error m.mux.Lock() defer m.mux.Unlock() - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) + routerPair, err := routeToRouterPair(route) if err != nil { return fmt.Errorf("parse prefix: %w", err) } @@ -117,7 +118,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { m.mux.Lock() defer m.mux.Unlock() - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) + routerPair, err := routeToRouterPair(route) if err != nil { return fmt.Errorf("parse prefix: %w", err) } @@ -133,7 +134,13 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { if state.Routes == nil { state.Routes = map[string]struct{}{} } - state.Routes[route.Network.String()] = struct{}{} + + routeStr := route.Network.String() + if route.IsDynamic() { + routeStr = route.Domains.SafeString() + } + state.Routes[routeStr] = struct{}{} + m.statusRecorder.UpdateLocalPeerState(state) return nil @@ -144,7 +151,7 @@ func (m *defaultServerRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r) + routerPair, err := routeToRouterPair(r) if err != nil { log.Errorf("Failed to convert route to router pair: %v", err) continue @@ -162,15 +169,27 @@ func (m *defaultServerRouter) cleanUp() { m.statusRecorder.UpdateLocalPeerState(state) } -func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { - parsed, err := netip.ParsePrefix(source) - if err != nil { - return firewall.RouterPair{}, err +func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { + // TODO: add ipv6 + source := getDefaultPrefix(route.Network) + + destination := route.Network.Masked().String() + if route.IsDynamic() { + // TODO: add ipv6 + destination = "0.0.0.0/0" } + return firewall.RouterPair{ ID: string(route.ID), - Source: parsed.String(), - Destination: route.Network.Masked().String(), + Source: source.String(), + Destination: destination, Masquerade: route.Masquerade, }, nil } + +func getDefaultPrefix(prefix netip.Prefix) netip.Prefix { + if prefix.Addr().Is6() { + return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + } + return netip.PrefixFrom(netip.IPv4Unspecified(), 0) +} diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go new file mode 100644 index 000000000..88cca522a --- /dev/null +++ b/client/internal/routemanager/static/route.go @@ -0,0 +1,57 @@ +package static + +import ( + "context" + "fmt" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/route" +) + +type Route struct { + route *route.Route + routeRefCounter *refcounter.RouteRefCounter + allowedIPsRefcounter *refcounter.AllowedIPsRefCounter +} + +func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route { + return &Route{ + route: rt, + routeRefCounter: routeRefCounter, + allowedIPsRefcounter: allowedIPsRefCounter, + } +} + +// Route route methods +func (r *Route) String() string { + return r.route.Network.String() +} + +func (r *Route) AddRoute(context.Context) error { + _, err := r.routeRefCounter.Increment(r.route.Network, nil) + return err +} + +func (r *Route) RemoveRoute() error { + _, err := r.routeRefCounter.Decrement(r.route.Network) + return err +} + +func (r *Route) AddAllowedIPs(peerKey string) error { + if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil { + return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err) + } else if ref.Count > 1 && ref.Out != peerKey { + log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled", + r.route.Network, + ref.Out, + ) + } + return nil +} + +func (r *Route) RemoveAllowedIPs() error { + _, err := r.allowedIPsRefcounter.Decrement(r.route.Network) + return err +} diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go new file mode 100644 index 000000000..3f2937c89 --- /dev/null +++ b/client/internal/routemanager/sysctl/sysctl_linux.go @@ -0,0 +1,103 @@ +// go:build !android +package sysctl + +import ( + "fmt" + "net" + "os" + "strconv" + "strings" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/iface" +) + +const ( + rpFilterPath = "net.ipv4.conf.all.rp_filter" + rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter" + srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark" +) + +// Setup configures sysctl settings for RP filtering and source validation. +func Setup(wgIface *iface.WGIface) (map[string]int, error) { + keys := map[string]int{} + var result *multierror.Error + + oldVal, err := Set(srcValidMarkPath, 1, false) + if err != nil { + result = multierror.Append(result, err) + } else { + keys[srcValidMarkPath] = oldVal + } + + oldVal, err = Set(rpFilterPath, 2, true) + if err != nil { + result = multierror.Append(result, err) + } else { + keys[rpFilterPath] = oldVal + } + + interfaces, err := net.Interfaces() + if err != nil { + result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err)) + } + + for _, intf := range interfaces { + if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() { + continue + } + + i := fmt.Sprintf(rpFilterInterfacePath, intf.Name) + oldVal, err := Set(i, 2, true) + if err != nil { + result = multierror.Append(result, err) + } else { + keys[i] = oldVal + } + } + + return keys, nberrors.FormatErrorOrNil(result) +} + +// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1 +func Set(key string, desiredValue int, onlyIfOne bool) (int, error) { + path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/")) + currentValue, err := os.ReadFile(path) + if err != nil { + return -1, fmt.Errorf("read sysctl %s: %w", key, err) + } + + currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue))) + if err != nil && len(currentValue) > 0 { + return -1, fmt.Errorf("convert current desiredValue to int: %w", err) + } + + if currentV == desiredValue || onlyIfOne && currentV != 1 { + return currentV, nil + } + + //nolint:gosec + if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil { + return currentV, fmt.Errorf("write sysctl %s: %w", key, err) + } + log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue) + + return currentV, nil +} + +// Cleanup resets sysctl settings to their original values. +func Cleanup(originalSettings map[string]int) error { + var result *multierror.Error + + for key, value := range originalSettings { + _, err := Set(key, value, false) + if err != nil { + result = multierror.Append(result, err) + } + } + + return nberrors.FormatErrorOrNil(result) +} diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go deleted file mode 100644 index bc506411c..000000000 --- a/client/internal/routemanager/systemops.go +++ /dev/null @@ -1,414 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "context" - "errors" - "fmt" - "net" - "net/netip" - "runtime" - "strconv" - - "github.com/hashicorp/go-multierror" - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" -) - -var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) -var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) -var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) -var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) - -var ErrRouteNotFound = errors.New("route not found") -var ErrRouteNotAllowed = errors.New("route not allowed") - -// TODO: fix: for default our wg address now appears as the default gw -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - addr := netip.IPv4Unspecified() - if prefix.Addr().Is6() { - addr = netip.IPv6Unspecified() - } - - defaultGateway, _, err := GetNextHop(addr) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - if !prefix.Contains(defaultGateway) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) - if defaultGateway.Is6() { - gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) - } - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - gatewayHop, intf, err := GetNextHop(defaultGateway) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop, intf) -} - -func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { - r, err := netroute.New() - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) - } - intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) - if err != nil { - log.Debugf("Failed to get route for %s: %v", ip, err) - return netip.Addr{}, nil, ErrRouteNotFound - } - - log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) - if gateway == nil { - if preferredSrc == nil { - return netip.Addr{}, nil, ErrRouteNotFound - } - log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - - addr, err := ipToAddr(preferredSrc, intf) - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err) - } - return addr.Unmap(), intf, nil - } - - addr, err := ipToAddr(gateway, intf) - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err) - } - - return addr, intf, nil -} - -// converts a net.IP to a netip.Addr including the zone based on the passed interface -func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip) - } - - if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) { - log.Tracef("Adding zone %s to address %s", intf.Name, addr) - if runtime.GOOS == "windows" { - addr = addr.WithZone(strconv.Itoa(intf.Index)) - } else { - addr = addr.WithZone(intf.Name) - } - } - - return addr.Unmap(), nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - -// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. -// If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) { - addr := prefix.Addr() - switch { - case addr.IsLoopback(), - addr.IsLinkLocalUnicast(), - addr.IsLinkLocalMulticast(), - addr.IsInterfaceLocalMulticast(), - addr.IsUnspecified(), - addr.IsMulticast(): - - return netip.Addr{}, nil, ErrRouteNotAllowed - } - - // Determine the exit interface and next hop for the prefix, so we can add a specific route - nexthop, intf, err := GetNextHop(addr) - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err) - } - - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) - exitNextHop := nexthop - exitIntf := intf - - vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr") - } - - // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values - if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() { - log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) - exitNextHop = initialNextHop - exitIntf = initialIntf - } - - log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) - if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { - return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err) - } - - return exitNextHop, exitIntf, nil -} - -// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix -// in two /1 prefixes to avoid replacing the existing default route -func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { - if prefix == defaultv4 { - if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - return err - } - if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return err - } - - // TODO: remove once IPv6 is supported on the interface - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } else if prefix == defaultv6 { - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } - - return addNonExistingRoute(prefix, intf) -} - -// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, netip.Addr{}, intf) -} - -// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, -// it will remove the split /1 prefixes -func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { - if prefix == defaultv4 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - // TODO: remove once IPv6 is supported on the interface - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } else if prefix == defaultv6 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } - - return removeFromRouteTable(prefix, netip.Addr{}, intf) -} - -func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return nil, fmt.Errorf("parse IP address: %s", ip) - } - addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return nil, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) - return &prefix, nil -} - -func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified()) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - log.Errorf("Unable to get initial v4 default next hop: %v", err) - } - initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified()) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - log.Errorf("Unable to get initial v6 default next hop: %v", err) - } - - *routeManager = NewRouteManager( - func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) { - addr := prefix.Addr() - nexthop, intf := initialNextHopV4, initialIntfV4 - if addr.Is6() { - nexthop, intf = initialNextHopV6, initialIntfV6 - } - return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) - }, - removeFromRouteTable, - ) - - return setupHooks(*routeManager, initAddresses) -} - -func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { - if routeManager == nil { - return nil - } - - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() - - if err := routeManager.Flush(); err != nil { - return fmt.Errorf("flush route manager: %w", err) - } - - return nil -} - -func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := getPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - - if err := routeManager.AddRouteRef(connID, *prefix); err != nil { - return fmt.Errorf("adding route reference: %v", err) - } - - return nil - } - afterHook := func(connID nbnet.ConnectionID) error { - if err := routeManager.RemoveRouteRef(connID); err != nil { - return fmt.Errorf("remove route reference: %w", err) - } - - return nil - } - - for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) - } - } - - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } - - var result *multierror.Error - for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) - } - return result.ErrorOrNil() - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - return beforeHook, afterHook, nil -} diff --git a/client/internal/routemanager/systemops/routeflags_bsd.go b/client/internal/routemanager/systemops/routeflags_bsd.go new file mode 100644 index 000000000..12f158dcb --- /dev/null +++ b/client/internal/routemanager/systemops/routeflags_bsd.go @@ -0,0 +1,18 @@ +//go:build darwin || dragonfly || netbsd || openbsd + +package systemops + +import "syscall" + +// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags. +func filterRoutesByFlags(routeMessageFlags int) bool { + if routeMessageFlags&syscall.RTF_UP == 0 { + return true + } + + if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 { + return true + } + + return false +} diff --git a/client/internal/routemanager/systemops/routeflags_freebsd.go b/client/internal/routemanager/systemops/routeflags_freebsd.go new file mode 100644 index 000000000..cb35f521e --- /dev/null +++ b/client/internal/routemanager/systemops/routeflags_freebsd.go @@ -0,0 +1,19 @@ +//go:build: freebsd +package systemops + +import "syscall" + +// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags. +func filterRoutesByFlags(routeMessageFlags int) bool { + if routeMessageFlags&syscall.RTF_UP == 0 { + return true + } + + // NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/) + // a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated. + if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 { + return true + } + + return false +} diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go new file mode 100644 index 000000000..9ee51538b --- /dev/null +++ b/client/internal/routemanager/systemops/systemops.go @@ -0,0 +1,27 @@ +package systemops + +import ( + "net" + "net/netip" + + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/iface" +) + +type Nexthop struct { + IP netip.Addr + Intf *net.Interface +} + +type ExclusionCounter = refcounter.Counter[any, Nexthop] + +type SysOps struct { + refCounter *ExclusionCounter + wgInterface *iface.WGIface +} + +func NewSysOps(wgInterface *iface.WGIface) *SysOps { + return &SysOps{ + wgInterface: wgInterface, + } +} diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops/systemops_bsd.go similarity index 94% rename from client/internal/routemanager/systemops_bsd.go rename to client/internal/routemanager/systemops/systemops_bsd.go index a3548a1f1..b7fb554db 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops/systemops_bsd.go @@ -1,6 +1,6 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd -package routemanager +package systemops import ( "errors" @@ -43,8 +43,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) { return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type) } - if m.Flags&syscall.RTF_UP == 0 || - m.Flags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 { + if filterRoutesByFlags(m.Flags) { continue } @@ -93,7 +92,7 @@ func toNetIP(a route.Addr) netip.Addr { case *route.Inet6Addr: ip := netip.AddrFrom16(t.IP) if t.ZoneID != 0 { - ip.WithZone(strconv.Itoa(t.ZoneID)) + ip = ip.WithZone(strconv.Itoa(t.ZoneID)) } return ip default: @@ -101,6 +100,7 @@ func toNetIP(a route.Addr) netip.Addr { } } +// ones returns the number of leading ones in the mask. func ones(a route.Addr) (int, error) { switch t := a.(type) { case *route.Inet4Addr: @@ -114,6 +114,7 @@ func ones(a route.Addr) (int, error) { } } +// MsgToRoute converts a route message to a Route. func MsgToRoute(msg *route.RouteMessage) (*Route, error) { dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2] diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go similarity index 72% rename from client/internal/routemanager/systemops_darwin_test.go rename to client/internal/routemanager/systemops/systemops_bsd_test.go index c23a7cde3..ce9a9082a 100644 --- a/client/internal/routemanager/systemops_darwin_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -1,6 +1,6 @@ -//go:build !ios +//go:build darwin || dragonfly || freebsd || netbsd || openbsd -package routemanager +package systemops import ( "fmt" @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/net/route" ) var expectedVPNint = "utun100" @@ -35,13 +36,15 @@ func TestConcurrentRoutes(t *testing.T) { baseIP := netip.MustParseAddr("192.0.2.0") intf := &net.Interface{Name: "lo0"} + r := NewSysOps(nil) + var wg sync.WaitGroup for i := 0; i < 1024; i++ { wg.Add(1) go func(ip netip.Addr) { defer wg.Done() prefix := netip.PrefixFrom(ip, 32) - if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil { + if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { t.Errorf("Failed to add route for %s: %v", prefix, err) } }(baseIP) @@ -57,7 +60,7 @@ func TestConcurrentRoutes(t *testing.T) { go func(ip netip.Addr) { defer wg.Done() prefix := netip.PrefixFrom(ip, 32) - if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil { + if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { t.Errorf("Failed to remove route for %s: %v", prefix, err) } }(baseIP) @@ -67,6 +70,53 @@ func TestConcurrentRoutes(t *testing.T) { wg.Wait() } +func TestBits(t *testing.T) { + tests := []struct { + name string + addr route.Addr + want int + wantErr bool + }{ + { + name: "IPv4 all ones", + addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}}, + want: 32, + }, + { + name: "IPv4 normal mask", + addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}}, + want: 24, + }, + { + name: "IPv6 all ones", + addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}}, + want: 128, + }, + { + name: "IPv6 normal mask", + addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}}, + want: 64, + }, + { + name: "Unsupported type", + addr: &route.LinkAddr{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ones(tt.addr) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { t.Helper() diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go new file mode 100644 index 000000000..0d1c16ca1 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -0,0 +1,473 @@ +//go:build !android && !ios + +package systemops + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "runtime" + "strconv" + + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/internal/routemanager/vars" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var ErrRoutingIsSeparate = errors.New("routing is separate") + +func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) + if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified()) + if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + refCounter := refcounter.New( + func(prefix netip.Prefix, _ any) (Nexthop, error) { + initialNexthop := initialNextHopV4 + if prefix.Addr().Is6() { + initialNexthop = initialNextHopV6 + } + + nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop) + if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) { + log.Tracef("Adding for prefix %s: %v", prefix, err) + // These errors are not critical but also we should not track and try to remove the routes either. + return nexthop, refcounter.ErrIgnore + } + return nexthop, err + }, + r.removeFromRouteTable, + ) + + r.refCounter = refCounter + + return r.setupHooks(initAddresses) +} + +func (r *SysOps) cleanupRefCounter() error { + if r.refCounter == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := r.refCounter.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + +// TODO: fix: for default our wg address now appears as the default gw +func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + nexthop, err := GetNextHop(addr) + if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(nexthop.IP) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32) + if nexthop.IP.Is6() { + gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + nexthop, err = GetNextHop(nexthop.IP) + if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP) + return r.addToRouteTable(gatewayPrefix, nexthop) +} + +// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. +// If the next hop or interface is pointing to the VPN interface, it will return the initial values. +func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(), + addr.IsLinkLocalUnicast(), + addr.IsLinkLocalMulticast(), + addr.IsInterfaceLocalMulticast(), + addr.IsUnspecified(), + addr.IsMulticast(): + + return Nexthop{}, vars.ErrRouteNotAllowed + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, err := GetNextHop(addr) + if err != nil { + return Nexthop{}, fmt.Errorf("get next hop: %w", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP) + exitNextHop := Nexthop{ + IP: nexthop.IP, + Intf: nexthop.Intf, + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop) + + exitNextHop = initialNextHop + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP) + if err := r.addToRouteTable(prefix, exitNextHop); err != nil { + return Nexthop{}, fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, nil +} + +// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + nextHop := Nexthop{netip.Addr{}, intf} + + if prefix == vars.Defaultv4 { + if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil { + return err + } + if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil { + if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { + if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == vars.Defaultv6 { + if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { + if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return r.addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}) +} + +// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + nextHop := Nexthop{netip.Addr{}, intf} + + if prefix == vars.Defaultv4 { + var result *multierror.Error + if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil { + result = multierror.Append(result, err) + } + if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { + result = multierror.Append(result, err) + } + if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { + result = multierror.Append(result, err) + } + + return nberrors.FormatErrorOrNil(result) + } else if prefix == vars.Defaultv6 { + var result *multierror.Error + if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { + result = multierror.Append(result, err) + } + if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { + result = multierror.Append(result, err) + } + + return nberrors.FormatErrorOrNil(result) + } + + return r.removeFromRouteTable(prefix, nextHop) +} + +func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := util.GetPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := r.refCounter.DecrementWithID(string(connID)); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return nberrors.FormatErrorOrNil(result) + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} + +func GetNextHop(ip netip.Addr) (Nexthop, error) { + r, err := netroute.New() + if err != nil { + return Nexthop{}, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Debugf("Failed to get route for %s: %v", ip, err) + return Nexthop{}, vars.ErrRouteNotFound + } + + log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + if runtime.GOOS == "freebsd" { + return Nexthop{Intf: intf}, nil + } + + if preferredSrc == nil { + return Nexthop{}, vars.ErrRouteNotFound + } + log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc) + + addr, err := ipToAddr(preferredSrc, intf) + if err != nil { + return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err) + } + return Nexthop{ + IP: addr, + Intf: intf, + }, nil + } + + addr, err := ipToAddr(gateway, intf) + if err != nil { + return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err) + } + + return Nexthop{ + IP: addr, + Intf: intf, + }, nil +} + +// converts a net.IP to a netip.Addr including the zone based on the passed interface +func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip) + } + + if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) { + zone := intf.Name + if runtime.GOOS == "windows" { + zone = strconv.Itoa(intf.Index) + } + log.Tracef("Adding zone %s to address %s", zone, addr) + addr = addr.WithZone(zone) + } + + return addr.Unmap(), nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. +func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { + localRoutes, err := hasSeparateRouting() + if err != nil { + if !errors.Is(err, ErrRoutingIsSeparate) { + log.Errorf("Failed to get routes: %v", err) + } + return false, netip.Prefix{} + } + + return isVpnRoute(addr, vpnRoutes, localRoutes) +} + +func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.Prefix) (bool, netip.Prefix) { + vpnPrefixMap := map[netip.Prefix]struct{}{} + for _, prefix := range vpnRoutes { + vpnPrefixMap[prefix] = struct{}{} + } + + // remove vpnRoute duplicates + for _, prefix := range localRoutes { + delete(vpnPrefixMap, prefix) + } + + var longestPrefix netip.Prefix + var isVpn bool + + combinedRoutes := make([]netip.Prefix, len(vpnRoutes)+len(localRoutes)) + copy(combinedRoutes, vpnRoutes) + copy(combinedRoutes[len(vpnRoutes):], localRoutes) + + for _, prefix := range combinedRoutes { + // Ignore the default route, it has special handling + if prefix.Bits() == 0 { + continue + } + + if prefix.Contains(addr) { + // Longest prefix match + if !longestPrefix.IsValid() || prefix.Bits() > longestPrefix.Bits() { + longestPrefix = prefix + _, isVpn = vpnPrefixMap[prefix] + } + } + } + + if !longestPrefix.IsValid() { + // No route matched + return false, netip.Prefix{} + } + + // Return true if the longest matching prefix is from vpnRoutes + return isVpn, longestPrefix +} diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go similarity index 59% rename from client/internal/routemanager/systemops_test.go rename to client/internal/routemanager/systemops/systemops_generic_test.go index 8bcf06dce..292166582 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -1,6 +1,6 @@ //go:build !android && !ios -package routemanager +package systemops import ( "bytes" @@ -49,6 +49,10 @@ func TestAddRemoveRoutes(t *testing.T) { } for n, testCase := range testCases { + // todo resolve test execution on freebsd + if runtime.GOOS == "freebsd" { + t.Skip("skipping ", testCase.name, " on freebsd") + } t.Run(testCase.name, func(t *testing.T) { t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") @@ -57,23 +61,26 @@ func TestAddRemoveRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, wgInterface) + + r := NewSysOps(wgInterface) + + _, _, err = r.SetupRouting(nil) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) + assert.NoError(t, r.CleanupRouting()) }) index, err := net.InterfaceByName(wgInterface.Name()) require.NoError(t, err, "InterfaceByName should not return err") intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} - err = addVPNRoute(testCase.prefix, intf) + err = r.AddVPNRoute(testCase.prefix, intf) require.NoError(t, err, "genericAddVPNRoute should not return err") if testCase.shouldRouteToWireguard { @@ -84,19 +91,19 @@ func TestAddRemoveRoutes(t *testing.T) { exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = removeVPNRoute(testCase.prefix, intf) + err = r.RemoveVPNRoute(testCase.prefix, intf) require.NoError(t, err, "genericRemoveVPNRoute should not return err") - prefixGateway, _, err := GetNextHop(testCase.prefix.Addr()) + prefixNexthop, err := GetNextHop(testCase.prefix.Addr()) require.NoError(t, err, "GetNextHop should not return err") - internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) + internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) require.NoError(t, err) if testCase.shouldBeRemoved { - require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") + require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway") } else { - require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") + require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway") } } }) @@ -104,11 +111,14 @@ func TestAddRemoveRoutes(t *testing.T) { } func TestGetNextHop(t *testing.T) { - gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) + if runtime.GOOS == "freebsd" { + t.Skip("skipping on freebsd") + } + nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } - if !gateway.IsValid() { + if !nexthop.IP.IsValid() { t.Fatal("should return a gateway") } addresses, err := net.InterfaceAddrs() @@ -130,24 +140,24 @@ func TestGetNextHop(t *testing.T) { } } - localIP, _, err := GetNextHop(testingPrefix.Addr()) + localIP, err := GetNextHop(testingPrefix.Addr()) if err != nil { t.Fatal("shouldn't return error: ", err) } - if !localIP.IsValid() { + if !localIP.IP.IsValid() { t.Fatal("should return a gateway for local network") } - if localIP.String() == gateway.String() { - t.Fatal("local ip should not match with gateway IP") + if localIP.IP.String() == nexthop.IP.String() { + t.Fatal("local IP should not match with gateway IP") } - if localIP.String() != testingIP { - t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) + if localIP.IP.String() != testingIP { + t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String()) } } func TestAddExistAndRemoveRoute(t *testing.T) { - defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) - t.Log("defaultGateway: ", defaultGateway) + defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) + t.Log("defaultNexthop: ", defaultNexthop) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } @@ -164,7 +174,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { }, { name: "Should Not Add Route if overlaps with default gateway", - prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"), + prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"), shouldAddRoute: false, }, { @@ -203,7 +213,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -214,14 +224,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) { require.NoError(t, err, "InterfaceByName should not return err") intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} + r := NewSysOps(wgInterface) + // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := addVPNRoute(testCase.preExistingPrefix, intf) + err := r.AddVPNRoute(testCase.preExistingPrefix, intf) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = addVPNRoute(testCase.prefix, intf) + err = r.AddVPNRoute(testCase.prefix, intf) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -231,7 +243,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = removeVPNRoute(testCase.prefix, intf) + err = r.RemoveVPNRoute(testCase.prefix, intf) require.NoError(t, err, "should not return err") } @@ -295,19 +307,22 @@ func TestExistsInRouteTable(t *testing.T) { var addressPrefixes []netip.Prefix for _, address := range addresses { p := netip.MustParsePrefix(address.String()) - if p.Addr().Is6() { - continue - } - // Windows sometimes has hidden interface link local addrs that don't turn up on any interface - if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { - continue - } - // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence - if runtime.GOOS == "linux" && p.Addr().IsLoopback() { - continue - } - addressPrefixes = append(addressPrefixes, p.Masked()) + switch { + case p.Addr().Is6(): + continue + // Windows sometimes has hidden interface link local addrs that don't turn up on any interface + case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast(): + continue + // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence + case runtime.GOOS == "linux" && p.Addr().IsLoopback(): + continue + // FreeBSD loopback 127/8 is not added to the routing table + case runtime.GOOS == "freebsd" && p.Addr().IsLoopback(): + continue + default: + addressPrefixes = append(addressPrefixes, p.Masked()) + } } for _, prefix := range addressPrefixes { @@ -330,7 +345,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen newNet, err := stdnet.NewNet() require.NoError(t, err) - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) require.NoError(t, err, "should create testing WireGuard interface") err = wgInterface.Create() @@ -343,65 +358,52 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen return wgInterface } +func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) { + t.Helper() + + err := r.AddVPNRoute(prefix, intf) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = r.RemoveVPNRoute(prefix, intf) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) +} + func setupTestEnv(t *testing.T) { t.Helper() setupDummyInterfacesAndRoutes(t) - wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) + wgInterface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) + assert.NoError(t, wgInterface.Close()) }) - _, _, err := setupRouting(nil, wgIface) + r := NewSysOps(wgInterface) + _, _, err := r.SetupRouting(nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) + assert.NoError(t, r.CleanupRouting()) }) - index, err := net.InterfaceByName(wgIface.Name()) + index, err := net.InterfaceByName(wgInterface.Name()) require.NoError(t, err, "InterfaceByName should not return err") - intf := &net.Interface{Index: index.Index, Name: wgIface.Name()} + intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} // default route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) + setupRouteAndCleanup(t, r, netip.MustParsePrefix("0.0.0.0/0"), intf) // 10.0.0.0/8 route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) + setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.0.0.0/8"), intf) // 10.10.0.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) + setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf) // 127.0.10.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) + setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf) // unique route in vpn table - err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) + setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf) } func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { @@ -410,11 +412,133 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf return } - prefixGateway, _, err := GetNextHop(prefix.Addr()) + prefixNexthop, err := GetNextHop(prefix.Addr()) require.NoError(t, err, "GetNextHop should not return err") if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP") } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP") + } +} + +func TestIsVpnRoute(t *testing.T) { + tests := []struct { + name string + addr string + vpnRoutes []string + localRoutes []string + expectedVpn bool + expectedPrefix netip.Prefix + }{ + { + name: "Match in VPN routes", + addr: "192.168.1.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: true, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + { + name: "Match in local routes", + addr: "10.1.1.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: false, + expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"), + }, + { + name: "No match", + addr: "172.16.0.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: false, + expectedPrefix: netip.Prefix{}, + }, + { + name: "Default route ignored", + addr: "192.168.1.1", + vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: true, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + { + name: "Default route matches but ignored", + addr: "172.16.1.1", + vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: false, + expectedPrefix: netip.Prefix{}, + }, + { + name: "Longest prefix match local", + addr: "192.168.1.1", + vpnRoutes: []string{"192.168.0.0/16"}, + localRoutes: []string{"192.168.1.0/24"}, + expectedVpn: false, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + { + name: "Longest prefix match local multiple", + addr: "192.168.0.1", + vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"}, + localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"}, + expectedVpn: false, + expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"), + }, + { + name: "Longest prefix match vpn", + addr: "192.168.1.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"192.168.0.0/16"}, + expectedVpn: true, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + { + name: "Longest prefix match vpn multiple", + addr: "192.168.0.1", + vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"}, + localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"}, + expectedVpn: true, + expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"), + }, + { + name: "Duplicate prefix in both", + addr: "192.168.1.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"192.168.1.0/24"}, + expectedVpn: false, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, err := netip.ParseAddr(tt.addr) + if err != nil { + t.Fatalf("Failed to parse address %s: %v", tt.addr, err) + } + + var vpnRoutes, localRoutes []netip.Prefix + for _, route := range tt.vpnRoutes { + prefix, err := netip.ParsePrefix(route) + if err != nil { + t.Fatalf("Failed to parse VPN route %s: %v", route, err) + } + vpnRoutes = append(vpnRoutes, prefix) + } + + for _, route := range tt.localRoutes { + prefix, err := netip.ParsePrefix(route) + if err != nil { + t.Fatalf("Failed to parse local route %s: %v", route, err) + } + localRoutes = append(localRoutes, prefix) + } + + isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes) + assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value") + assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix") + }) } } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go similarity index 72% rename from client/internal/routemanager/systemops_linux.go rename to client/internal/routemanager/systemops/systemops_linux.go index ce0c07ce6..c4f69fba5 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -1,6 +1,6 @@ //go:build !android -package routemanager +package systemops import ( "bufio" @@ -9,16 +9,15 @@ import ( "net" "net/netip" "os" - "strconv" - "strings" "syscall" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" + "github.com/netbirdio/netbird/client/internal/routemanager/vars" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -33,16 +32,10 @@ const ( // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. ipv4ForwardingPath = "net.ipv4.ip_forward" - - rpFilterPath = "net.ipv4.conf.all.rp_filter" - rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter" - srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark" ) var ErrTableIDExists = errors.New("ID exists with different name") -var routeManager = &RouteManager{} - // originalSysctl stores the original sysctl values before they are modified var originalSysctl map[string]int @@ -82,7 +75,7 @@ func getSetupRules() []ruleParams { } } -// setupRouting establishes the routing configuration for the VPN, including essential rules +// SetupRouting establishes the routing configuration for the VPN, including essential rules // to ensure proper traffic flow for management, locally configured routes, and VPN traffic. // // Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over @@ -92,17 +85,17 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { if isLegacy() { log.Infof("Using legacy routing setup") - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + return r.setupRefCounter(initAddresses) } if err = addRoutingTableName(); err != nil { log.Errorf("Error adding routing table name: %v", err) } - originalValues, err := setupSysctl(wgIface) + originalValues, err := sysctl.Setup(r.wgInterface) if err != nil { log.Errorf("Error setting up sysctl: %v", err) sysctlFailed = true @@ -111,7 +104,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before defer func() { if err != nil { - if cleanErr := cleanupRouting(); cleanErr != nil { + if cleanErr := r.CleanupRouting(); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -123,7 +116,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before if errors.Is(err, syscall.EOPNOTSUPP) { log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") setIsLegacy(true) - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + return r.setupRefCounter(initAddresses) } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } @@ -132,12 +125,12 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before return nil, nil, nil } -// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. +// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. -func cleanupRouting() error { +func (r *SysOps) CleanupRouting() error { if isLegacy() { - return cleanupRoutingWithRouteManager(routeManager) + return r.cleanupRefCounter() } var result *multierror.Error @@ -156,58 +149,58 @@ func cleanupRouting() error { } } - if err := cleanupSysctl(originalSysctl); err != nil { + if err := sysctl.Cleanup(originalSysctl); err != nil { result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err)) } originalSysctl = nil sysctlFailed = false - return result.ErrorOrNil() + return nberrors.FormatErrorOrNil(result) } -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { - return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return addRoute(prefix, nexthop, syscall.RT_TABLE_MAIN) } -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { - return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return removeRoute(prefix, nexthop, syscall.RT_TABLE_MAIN) } -func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error { +func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { if isLegacy() { - return genericAddVPNRoute(prefix, intf) + return r.genericAddVPNRoute(prefix, intf) } - if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) { + if sysctlFailed && (prefix == vars.Defaultv4 || prefix == vars.Defaultv6) { log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)") } // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // TODO remove this once we have ipv6 support - if prefix == defaultv4 { - if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { + if prefix == vars.Defaultv4 { + if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("add blackhole: %w", err) } } - if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { + if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { return fmt.Errorf("add route: %w", err) } return nil } -func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error { +func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { if isLegacy() { - return genericRemoveVPNRoute(prefix, intf) + return r.genericRemoveVPNRoute(prefix, intf) } // TODO remove this once we have ipv6 support - if prefix == defaultv4 { - if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { + if prefix == vars.Defaultv4 { + if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove unreachable route: %w", err) } } - if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { + if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove route: %w", err) } return nil @@ -255,7 +248,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) { } // addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error { +func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Table: tableID, @@ -268,7 +261,7 @@ func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID } route.Dst = ipNet - if err := addNextHop(addr, intf, route); err != nil { + if err := addNextHop(nexthop, route); err != nil { return fmt.Errorf("add gateway and device: %w", err) } @@ -327,7 +320,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { } // removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error { +func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -340,7 +333,7 @@ func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tabl Dst: ipNet, } - if err := addNextHop(addr, intf, route); err != nil { + if err := addNextHop(nexthop, route); err != nil { return fmt.Errorf("add gateway and device: %w", err) } @@ -373,11 +366,11 @@ func flushRoutes(tableID, family int) error { } } - return result.ErrorOrNil() + return nberrors.FormatErrorOrNil(result) } -func enableIPForwarding() error { - _, err := setSysctl(ipv4ForwardingPath, 1, false) +func EnableIPForwarding() error { + _, err := sysctl.Set(ipv4ForwardingPath, 1, false) return err } @@ -481,19 +474,19 @@ func removeRule(params ruleParams) error { } // addNextHop adds the gateway and device to the route. -func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error { - if intf != nil { - route.LinkIndex = intf.Index +func addNextHop(nexthop Nexthop, route *netlink.Route) error { + if nexthop.Intf != nil { + route.LinkIndex = nexthop.Intf.Index } - if addr.IsValid() { - route.Gw = addr.AsSlice() + if nexthop.IP.IsValid() { + route.Gw = nexthop.IP.AsSlice() // if zone is set, it means the gateway is a link-local address, so we set the link index - if addr.Zone() != "" && intf == nil { - link, err := netlink.LinkByName(addr.Zone()) + if nexthop.IP.Zone() != "" && nexthop.Intf == nil { + link, err := netlink.LinkByName(nexthop.IP.Zone()) if err != nil { - return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err) + return fmt.Errorf("get link by name for zone %s: %w", nexthop.IP.Zone(), err) } route.LinkIndex = link.Attrs().Index } @@ -509,82 +502,9 @@ func getAddressFamily(prefix netip.Prefix) int { return netlink.FAMILY_V6 } -// setupSysctl configures sysctl settings for RP filtering and source validation. -func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) { - keys := map[string]int{} - var result *multierror.Error - - oldVal, err := setSysctl(srcValidMarkPath, 1, false) - if err != nil { - result = multierror.Append(result, err) - } else { - keys[srcValidMarkPath] = oldVal +func hasSeparateRouting() ([]netip.Prefix, error) { + if isLegacy() { + return getRoutesFromTable() } - - oldVal, err = setSysctl(rpFilterPath, 2, true) - if err != nil { - result = multierror.Append(result, err) - } else { - keys[rpFilterPath] = oldVal - } - - interfaces, err := net.Interfaces() - if err != nil { - result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err)) - } - - for _, intf := range interfaces { - if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() { - continue - } - - i := fmt.Sprintf(rpFilterInterfacePath, intf.Name) - oldVal, err := setSysctl(i, 2, true) - if err != nil { - result = multierror.Append(result, err) - } else { - keys[i] = oldVal - } - } - - return keys, result.ErrorOrNil() -} - -// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1 -func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) { - path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/")) - currentValue, err := os.ReadFile(path) - if err != nil { - return -1, fmt.Errorf("read sysctl %s: %w", key, err) - } - - currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue))) - if err != nil && len(currentValue) > 0 { - return -1, fmt.Errorf("convert current desiredValue to int: %w", err) - } - - if currentV == desiredValue || onlyIfOne && currentV != 1 { - return currentV, nil - } - - //nolint:gosec - if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil { - return currentV, fmt.Errorf("write sysctl %s: %w", key, err) - } - log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue) - - return currentV, nil -} - -func cleanupSysctl(originalSettings map[string]int) error { - var result *multierror.Error - - for key, value := range originalSettings { - _, err := setSysctl(key, value, false) - if err != nil { - result = multierror.Append(result, err) - } - } - - return result.ErrorOrNil() + return nil, ErrRoutingIsSeparate } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops/systemops_linux_test.go similarity index 96% rename from client/internal/routemanager/systemops_linux_test.go rename to client/internal/routemanager/systemops/systemops_linux_test.go index 0043c3f4e..8f12740d0 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops/systemops_linux_test.go @@ -1,6 +1,6 @@ //go:build !android -package routemanager +package systemops import ( "errors" @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vishvananda/netlink" + + "github.com/netbirdio/netbird/client/internal/routemanager/vars" ) var expectedVPNint = "wgtest0" @@ -138,7 +140,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { if dstIPNet.String() == "0.0.0.0/0" { var err error originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil && !errors.Is(err, ErrRouteNotFound) { + if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { t.Logf("Failed to fetch original gateway: %v", err) } @@ -193,7 +195,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } } - return nil, 0, ErrRouteNotFound + return nil, 0, vars.ErrRouteNotFound } func setupDummyInterfacesAndRoutes(t *testing.T) { diff --git a/client/internal/routemanager/systemops/systemops_mobile.go b/client/internal/routemanager/systemops/systemops_mobile.go new file mode 100644 index 000000000..43815c657 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_mobile.go @@ -0,0 +1,38 @@ +//go:build ios || android + +package systemops + +import ( + "net" + "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return nil, nil, nil +} + +func (r *SysOps) CleanupRouting() error { + return nil +} + +func (r *SysOps) AddVPNRoute(netip.Prefix, *net.Interface) error { + return nil +} + +func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error { + return nil +} + +func EnableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) { + return false, netip.Prefix{} +} diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go new file mode 100644 index 000000000..0adeb0992 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -0,0 +1,28 @@ +//go:build !linux && !ios + +package systemops + +import ( + "net" + "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" +) + +func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return r.genericAddVPNRoute(prefix, intf) +} + +func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return r.genericRemoveVPNRoute(prefix, intf) +} + +func EnableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func hasSeparateRouting() ([]netip.Prefix, error) { + return getRoutesFromTable() +} diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops/systemops_unix.go similarity index 55% rename from client/internal/routemanager/systemops_darwin.go rename to client/internal/routemanager/systemops/systemops_unix.go index ee4196a0c..a2bbf35cf 100644 --- a/client/internal/routemanager/systemops_darwin.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -1,6 +1,6 @@ -//go:build darwin && !ios +//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd -package routemanager +package systemops import ( "fmt" @@ -13,43 +13,41 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) -var routeManager *RouteManager - -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses) } -func cleanupRouting() error { - return cleanupRoutingWithRouteManager(routeManager) +func (r *SysOps) CleanupRouting() error { + return r.cleanupRefCounter() } -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { - return routeCmd("add", prefix, nexthop, intf) +func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return r.routeCmd("add", prefix, nexthop) } -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { - return routeCmd("delete", prefix, nexthop, intf) +func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return r.routeCmd("delete", prefix, nexthop) } -func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { +func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error { inet := "-inet" - network := prefix.String() - if prefix.IsSingleIP() { - network = prefix.Addr().String() - } if prefix.Addr().Is6() { inet = "-inet6" } + network := prefix.String() + if prefix.IsSingleIP() { + network = prefix.Addr().String() + } + args := []string{"-n", action, inet, network} - if nexthop.IsValid() { - args = append(args, nexthop.Unmap().String()) - } else if intf != nil { - args = append(args, "-interface", intf.Name) + if nexthop.IP.IsValid() { + args = append(args, nexthop.IP.Unmap().String()) + } else if nexthop.Intf != nil { + args = append(args, "-interface", nexthop.Intf.Name) } if err := retryRouteCmd(args); err != nil { diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go similarity index 97% rename from client/internal/routemanager/systemops_unix_test.go rename to client/internal/routemanager/systemops/systemops_unix_test.go index 561eaeea4..a6000d963 100644 --- a/client/internal/routemanager/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -1,10 +1,11 @@ //go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly -package routemanager +package systemops import ( "fmt" "net" + "runtime" "strings" "testing" "time" @@ -85,6 +86,10 @@ var testCases = []testCase{ func TestRouting(t *testing.T) { for _, tc := range testCases { + // todo resolve test execution on freebsd + if runtime.GOOS == "freebsd" { + t.Skip("skipping ", tc.name, " on freebsd") + } t.Run(tc.name, func(t *testing.T) { setupTestEnv(t) diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go similarity index 77% rename from client/internal/routemanager/systemops_windows.go rename to client/internal/routemanager/systemops/systemops_windows.go index 32e94d8da..88bdce7c9 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -1,6 +1,6 @@ //go:build windows -package routemanager +package systemops import ( "fmt" @@ -17,8 +17,7 @@ import ( "github.com/yusufpapurcu/wmi" "github.com/netbirdio/netbird/client/firewall/uspfilter" - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) type MSFT_NetRoute struct { @@ -57,14 +56,42 @@ var prefixList []netip.Prefix var lastUpdate time.Time var mux = sync.Mutex{} -var routeManager *RouteManager - -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses) } -func cleanupRouting() error { - return cleanupRoutingWithRouteManager(routeManager) +func (r *SysOps) CleanupRouting() error { + return r.cleanupRefCounter() +} + +func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + if nexthop.IP.Zone() != "" && nexthop.Intf == nil { + zone, err := strconv.Atoi(nexthop.IP.Zone()) + if err != nil { + return fmt.Errorf("invalid zone: %w", err) + } + nexthop.Intf = &net.Interface{Index: zone} + } + + return addRouteCmd(prefix, nexthop) +} + +func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + args := []string{"delete", prefix.String()} + if nexthop.IP.IsValid() { + ip := nexthop.IP.WithZone("") + args = append(args, ip.Unmap().String()) + } + + routeCmd := uspfilter.GetSystem32Command("route") + + out, err := exec.Command(routeCmd, args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("remove route: %w", err) + } + return nil } func getRoutesFromTable() ([]netip.Prefix, error) { @@ -93,7 +120,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) { func GetRoutes() ([]Route, error) { var entries []MSFT_NetRoute - query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute` + query := `SELECT DestinationPrefix, Nexthop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute` if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil { return nil, fmt.Errorf("get routes: %w", err) } @@ -118,6 +145,10 @@ func GetRoutes() ([]Route, error) { Index: int(entry.InterfaceIndex), Name: entry.InterfaceAlias, } + + if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) { + nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex))) + } } routes = append(routes, Route{ @@ -157,11 +188,12 @@ func GetNeighbors() ([]Neighbor, error) { return neighbors, nil } -func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { +func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error { args := []string{"add", prefix.String()} - if nexthop.IsValid() { - args = append(args, nexthop.Unmap().String()) + if nexthop.IP.IsValid() { + ip := nexthop.IP.WithZone("") + args = append(args, ip.Unmap().String()) } else { addr := "0.0.0.0" if prefix.Addr().Is6() { @@ -170,8 +202,8 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e args = append(args, addr) } - if intf != nil { - args = append(args, "if", strconv.Itoa(intf.Index)) + if nexthop.Intf != nil { + args = append(args, "if", strconv.Itoa(nexthop.Intf.Index)) } routeCmd := uspfilter.GetSystem32Command("route") @@ -185,37 +217,6 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e return nil } -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { - if nexthop.Zone() != "" && intf == nil { - zone, err := strconv.Atoi(nexthop.Zone()) - if err != nil { - return fmt.Errorf("invalid zone: %w", err) - } - intf = &net.Interface{Index: zone} - nexthop.WithZone("") - } - - return addRouteCmd(prefix, nexthop, intf) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error { - args := []string{"delete", prefix.String()} - if nexthop.IsValid() { - nexthop.WithZone("") - args = append(args, nexthop.Unmap().String()) - } - - routeCmd := uspfilter.GetSystem32Command("route") - - out, err := exec.Command(routeCmd, args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) - - if err != nil { - return fmt.Errorf("remove route: %w", err) - } - return nil -} - func isCacheDisabled() bool { return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" } diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go similarity index 97% rename from client/internal/routemanager/systemops_windows_test.go rename to client/internal/routemanager/systemops/systemops_windows_test.go index a5e03b8d2..9180ed58c 100644 --- a/client/internal/routemanager/systemops_windows_test.go +++ b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -1,4 +1,4 @@ -package routemanager +package systemops import ( "context" @@ -29,7 +29,7 @@ type FindNetRouteOutput struct { InterfaceIndex int `json:"InterfaceIndex"` InterfaceAlias string `json:"InterfaceAlias"` AddressFamily int `json:"AddressFamily"` - NextHop string `json:"NextHop"` + NextHop string `json:"Nexthop"` DestinationPrefix string `json:"DestinationPrefix"` } @@ -166,7 +166,7 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut host, _, err := net.SplitHostPort(destination) require.NoError(t, err) - script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) + script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, Nexthop, DestinationPrefix | ConvertTo-Json`, host) out, err := exec.Command("powershell", "-Command", script).Output() require.NoError(t, err, "Failed to execute Find-NetRoute") @@ -207,7 +207,7 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str } func fetchOriginalGateway() (*RouteInfo, error) { - cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") + cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json") output, err := cmd.CombinedOutput() if err != nil { return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go deleted file mode 100644 index 4d23d3910..000000000 --- a/client/internal/routemanager/systemops_android.go +++ /dev/null @@ -1,33 +0,0 @@ -package routemanager - -import ( - "net" - "net/netip" - "runtime" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" -) - -func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - -func cleanupRouting() error { - return nil -} - -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil -} - -func addVPNRoute(netip.Prefix, *net.Interface) error { - return nil -} - -func removeVPNRoute(netip.Prefix, *net.Interface) error { - return nil -} diff --git a/client/internal/routemanager/systemops_bsd_test.go b/client/internal/routemanager/systemops_bsd_test.go deleted file mode 100644 index 81bca504c..000000000 --- a/client/internal/routemanager/systemops_bsd_test.go +++ /dev/null @@ -1,57 +0,0 @@ -//go:build darwin || dragonfly || freebsd || netbsd || openbsd - -package routemanager - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "golang.org/x/net/route" -) - -func TestBits(t *testing.T) { - tests := []struct { - name string - addr route.Addr - want int - wantErr bool - }{ - { - name: "IPv4 all ones", - addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}}, - want: 32, - }, - { - name: "IPv4 normal mask", - addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}}, - want: 24, - }, - { - name: "IPv6 all ones", - addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}}, - want: 128, - }, - { - name: "IPv6 normal mask", - addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}}, - want: 64, - }, - { - name: "Unsupported type", - addr: &route.LinkAddr{}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := ones(tt.addr) - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.want, got) - } - }) - } -} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go deleted file mode 100644 index 4d23d3910..000000000 --- a/client/internal/routemanager/systemops_ios.go +++ /dev/null @@ -1,33 +0,0 @@ -package routemanager - -import ( - "net" - "net/netip" - "runtime" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" -) - -func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - -func cleanupRouting() error { - return nil -} - -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil -} - -func addVPNRoute(netip.Prefix, *net.Interface) error { - return nil -} - -func removeVPNRoute(netip.Prefix, *net.Interface) error { - return nil -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go deleted file mode 100644 index 91879790a..000000000 --- a/client/internal/routemanager/systemops_nonlinux.go +++ /dev/null @@ -1,24 +0,0 @@ -//go:build !linux && !ios - -package routemanager - -import ( - "net" - "net/netip" - "runtime" - - log "github.com/sirupsen/logrus" -) - -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil -} - -func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error { - return genericAddVPNRoute(prefix, intf) -} - -func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error { - return genericRemoveVPNRoute(prefix, intf) -} diff --git a/client/internal/routemanager/util/ip.go b/client/internal/routemanager/util/ip.go new file mode 100644 index 000000000..ac5a48e37 --- /dev/null +++ b/client/internal/routemanager/util/ip.go @@ -0,0 +1,29 @@ +package util + +import ( + "fmt" + "net" + "net/netip" +) + +// GetPrefixFromIP returns a netip.Prefix from a net.IP address. +func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return prefix, nil +} diff --git a/client/internal/routemanager/vars/vars.go b/client/internal/routemanager/vars/vars.go new file mode 100644 index 000000000..4aa986d2f --- /dev/null +++ b/client/internal/routemanager/vars/vars.go @@ -0,0 +1,16 @@ +package vars + +import ( + "errors" + "net/netip" +) + +const MinRangeBits = 7 + +var ( + ErrRouteNotFound = errors.New("route not found") + ErrRouteNotAllowed = errors.New("route not allowed") + + Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) +) diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 1c17e8803..00128a27b 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -3,11 +3,11 @@ package routeselector import ( "fmt" "slices" - "strings" "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/client/errors" route "github.com/netbirdio/netbird/route" ) @@ -30,10 +30,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al rs.selectedRoutes = map[route.NetID]struct{}{} } - var multiErr *multierror.Error + var err *multierror.Error for _, route := range routes { if !slices.Contains(allRoutes, route) { - multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route)) + err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) continue } @@ -41,11 +41,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al } rs.selectAll = false - if multiErr != nil { - multiErr.ErrorFormat = formatError - } - - return multiErr.ErrorOrNil() + return errors.FormatErrorOrNil(err) } // SelectAllRoutes sets the selector to select all routes. @@ -65,21 +61,17 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route. } } - var multiErr *multierror.Error + var err *multierror.Error for _, route := range routes { if !slices.Contains(allRoutes, route) { - multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route)) + err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) continue } delete(rs.selectedRoutes, route) } - if multiErr != nil { - multiErr.ErrorFormat = formatError - } - - return multiErr.ErrorOrNil() + return errors.FormatErrorOrNil(err) } // DeselectAllRoutes deselects all routes, effectively disabling route selection. @@ -111,18 +103,3 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { } return filtered } - -func formatError(es []error) string { - if len(es) == 1 { - return fmt.Sprintf("1 error occurred:\n\t* %s", es[0]) - } - - points := make([]string, len(es)) - for i, err := range es { - points[i] = fmt.Sprintf("* %s", err) - } - - return fmt.Sprintf( - "%d errors occurred:\n\t%s", - len(es), strings.Join(points, "\n\t")) -} diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index fb1e456cd..7df433f92 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -261,15 +261,15 @@ func TestRouteSelector_FilterSelected(t *testing.T) { require.NoError(t, err) routes := route.HAMap{ - "route1-10.0.0.0/8": {}, - "route2-192.168.0.0/16": {}, - "route3-172.16.0.0/12": {}, + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, } filtered := rs.FilterSelected(routes) assert.Equal(t, route.HAMap{ - "route1-10.0.0.0/8": {}, - "route2-192.168.0.0/16": {}, + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, }, filtered) } diff --git a/client/internal/wgproxy/factory_linux.go b/client/internal/wgproxy/factory_linux.go index ba1ef8c45..d01ae7e74 100644 --- a/client/internal/wgproxy/factory_linux.go +++ b/client/internal/wgproxy/factory_linux.go @@ -4,21 +4,24 @@ package wgproxy import ( "context" + + log "github.com/sirupsen/logrus" ) -func NewFactory(ctx context.Context, wgPort int) *Factory { +func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory { f := &Factory{wgPort: wgPort} - // todo: put it back - /* - ebpfProxy := NewWGEBPFProxy(ctx, wgPort) - err := ebpfProxy.listen() - if err != nil { - log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) - return f - } - f.ebpfProxy = ebpfProxy + if userspace { + return f + } - */ + ebpfProxy := NewWGEBPFProxy(ctx, wgPort) + err := ebpfProxy.listen() + if err != nil { + log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) + return f + } + + f.ebpfProxy = ebpfProxy return f } diff --git a/client/internal/wgproxy/factory_nonlinux.go b/client/internal/wgproxy/factory_nonlinux.go index 33a235c4a..d1640c97d 100644 --- a/client/internal/wgproxy/factory_nonlinux.go +++ b/client/internal/wgproxy/factory_nonlinux.go @@ -4,6 +4,6 @@ package wgproxy import "context" -func NewFactory(ctx context.Context, wgPort int) *Factory { +func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory { return &Factory{wgPort: wgPort} } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 83c8278d5..813540246 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,17 +1,17 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.12.4 +// protoc v4.23.4 // source: daemon.proto package proto import ( - _ "github.com/golang/protobuf/protoc-gen-go/descriptor" - duration "github.com/golang/protobuf/ptypes/duration" - timestamp "github.com/golang/protobuf/ptypes/timestamp" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + _ "google.golang.org/protobuf/types/descriptorpb" + durationpb "google.golang.org/protobuf/types/known/durationpb" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" ) @@ -108,19 +108,20 @@ type LoginRequest struct { // cleanNATExternalIPs clean map list of external IPs. // This is needed because the generated code // omits initialized empty slices due to omitempty tags - CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` - CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` - IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"` - Hostname string `protobuf:"bytes,9,opt,name=hostname,proto3" json:"hostname,omitempty"` - RosenpassEnabled *bool `protobuf:"varint,10,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"` - InterfaceName *string `protobuf:"bytes,11,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"` - WireguardPort *int64 `protobuf:"varint,12,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"` - OptionalPreSharedKey *string `protobuf:"bytes,13,opt,name=optionalPreSharedKey,proto3,oneof" json:"optionalPreSharedKey,omitempty"` - DisableAutoConnect *bool `protobuf:"varint,14,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"` - ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` - RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` - ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` - NetworkMonitor *bool `protobuf:"varint,18,opt,name=networkMonitor,proto3,oneof" json:"networkMonitor,omitempty"` + CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` + CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` + IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"` + Hostname string `protobuf:"bytes,9,opt,name=hostname,proto3" json:"hostname,omitempty"` + RosenpassEnabled *bool `protobuf:"varint,10,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"` + InterfaceName *string `protobuf:"bytes,11,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"` + WireguardPort *int64 `protobuf:"varint,12,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"` + OptionalPreSharedKey *string `protobuf:"bytes,13,opt,name=optionalPreSharedKey,proto3,oneof" json:"optionalPreSharedKey,omitempty"` + DisableAutoConnect *bool `protobuf:"varint,14,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"` + ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` + RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` + ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` + NetworkMonitor *bool `protobuf:"varint,18,opt,name=networkMonitor,proto3,oneof" json:"networkMonitor,omitempty"` + DnsRouteInterval *durationpb.Duration `protobuf:"bytes,19,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` } func (x *LoginRequest) Reset() { @@ -282,6 +283,13 @@ func (x *LoginRequest) GetNetworkMonitor() bool { return false } +func (x *LoginRequest) GetDnsRouteInterval() *durationpb.Duration { + if x != nil { + return x.DnsRouteInterval + } + return nil +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -762,7 +770,13 @@ type GetConfigResponse struct { // preSharedKey settings value. PreSharedKey string `protobuf:"bytes,4,opt,name=preSharedKey,proto3" json:"preSharedKey,omitempty"` // adminURL settings value. - AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` + AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` + InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` + WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` + DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` + ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` + RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` } func (x *GetConfigResponse) Reset() { @@ -832,29 +846,71 @@ func (x *GetConfigResponse) GetAdminURL() string { return "" } +func (x *GetConfigResponse) GetInterfaceName() string { + if x != nil { + return x.InterfaceName + } + return "" +} + +func (x *GetConfigResponse) GetWireguardPort() int64 { + if x != nil { + return x.WireguardPort + } + return 0 +} + +func (x *GetConfigResponse) GetDisableAutoConnect() bool { + if x != nil { + return x.DisableAutoConnect + } + return false +} + +func (x *GetConfigResponse) GetServerSSHAllowed() bool { + if x != nil { + return x.ServerSSHAllowed + } + return false +} + +func (x *GetConfigResponse) GetRosenpassEnabled() bool { + if x != nil { + return x.RosenpassEnabled + } + return false +} + +func (x *GetConfigResponse) GetRosenpassPermissive() bool { + if x != nil { + return x.RosenpassPermissive + } + return false +} + // PeerState contains the latest state of a peer type PeerState struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` - PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` - ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"` - ConnStatusUpdate *timestamp.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"` - Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"` - Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"` - LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"` - RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"` - Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"` - LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"` - RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"` - LastWireguardHandshake *timestamp.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"` - BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` - BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` - RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` - Latency *duration.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` + IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` + PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` + ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"` + ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"` + Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"` + Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"` + LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"` + RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"` + Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"` + RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"` + LastWireguardHandshake *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"` + BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` + BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` + RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` + Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` } func (x *PeerState) Reset() { @@ -910,7 +966,7 @@ func (x *PeerState) GetConnStatus() string { return "" } -func (x *PeerState) GetConnStatusUpdate() *timestamp.Timestamp { +func (x *PeerState) GetConnStatusUpdate() *timestamppb.Timestamp { if x != nil { return x.ConnStatusUpdate } @@ -966,7 +1022,7 @@ func (x *PeerState) GetRemoteIceCandidateEndpoint() string { return "" } -func (x *PeerState) GetLastWireguardHandshake() *timestamp.Timestamp { +func (x *PeerState) GetLastWireguardHandshake() *timestamppb.Timestamp { if x != nil { return x.LastWireguardHandshake } @@ -1001,7 +1057,7 @@ func (x *PeerState) GetRoutes() []string { return nil } -func (x *PeerState) GetLatency() *duration.Duration { +func (x *PeerState) GetLatency() *durationpb.Duration { if x != nil { return x.Latency } @@ -1641,20 +1697,69 @@ func (*SelectRoutesResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{22} } +type IPList struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Ips []string `protobuf:"bytes,1,rep,name=ips,proto3" json:"ips,omitempty"` +} + +func (x *IPList) Reset() { + *x = IPList{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IPList) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IPList) ProtoMessage() {} + +func (x *IPList) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[23] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IPList.ProtoReflect.Descriptor instead. +func (*IPList) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{23} +} + +func (x *IPList) GetIps() []string { + if x != nil { + return x.Ips + } + return nil +} + type Route struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` - Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"` - Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"` + ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` + Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"` + Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"` + Domains []string `protobuf:"bytes,4,rep,name=domains,proto3" json:"domains,omitempty"` + ResolvedIPs map[string]*IPList `protobuf:"bytes,5,rep,name=resolvedIPs,proto3" json:"resolvedIPs,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } func (x *Route) Reset() { *x = Route{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1667,7 +1772,7 @@ func (x *Route) String() string { func (*Route) ProtoMessage() {} func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1680,7 +1785,7 @@ func (x *Route) ProtoReflect() protoreflect.Message { // Deprecated: Use Route.ProtoReflect.Descriptor instead. func (*Route) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{23} + return file_daemon_proto_rawDescGZIP(), []int{24} } func (x *Route) GetID() string { @@ -1704,6 +1809,20 @@ func (x *Route) GetSelected() bool { return false } +func (x *Route) GetDomains() []string { + if x != nil { + return x.Domains + } + return nil +} + +func (x *Route) GetResolvedIPs() map[string]*IPList { + if x != nil { + return x.ResolvedIPs + } + return nil +} + type DebugBundleRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1716,7 +1835,7 @@ type DebugBundleRequest struct { func (x *DebugBundleRequest) Reset() { *x = DebugBundleRequest{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1729,7 +1848,7 @@ func (x *DebugBundleRequest) String() string { func (*DebugBundleRequest) ProtoMessage() {} func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1742,7 +1861,7 @@ func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead. func (*DebugBundleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{24} + return file_daemon_proto_rawDescGZIP(), []int{25} } func (x *DebugBundleRequest) GetAnonymize() bool { @@ -1770,7 +1889,7 @@ type DebugBundleResponse struct { func (x *DebugBundleResponse) Reset() { *x = DebugBundleResponse{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1783,7 +1902,7 @@ func (x *DebugBundleResponse) String() string { func (*DebugBundleResponse) ProtoMessage() {} func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1796,7 +1915,7 @@ func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead. func (*DebugBundleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{25} + return file_daemon_proto_rawDescGZIP(), []int{26} } func (x *DebugBundleResponse) GetPath() string { @@ -1815,7 +1934,7 @@ type GetLogLevelRequest struct { func (x *GetLogLevelRequest) Reset() { *x = GetLogLevelRequest{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1828,7 +1947,7 @@ func (x *GetLogLevelRequest) String() string { func (*GetLogLevelRequest) ProtoMessage() {} func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1841,7 +1960,7 @@ func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead. func (*GetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26} + return file_daemon_proto_rawDescGZIP(), []int{27} } type GetLogLevelResponse struct { @@ -1855,7 +1974,7 @@ type GetLogLevelResponse struct { func (x *GetLogLevelResponse) Reset() { *x = GetLogLevelResponse{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1868,7 +1987,7 @@ func (x *GetLogLevelResponse) String() string { func (*GetLogLevelResponse) ProtoMessage() {} func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1881,7 +2000,7 @@ func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead. func (*GetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{27} + return file_daemon_proto_rawDescGZIP(), []int{28} } func (x *GetLogLevelResponse) GetLevel() LogLevel { @@ -1902,7 +2021,7 @@ type SetLogLevelRequest struct { func (x *SetLogLevelRequest) Reset() { *x = SetLogLevelRequest{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1915,7 +2034,7 @@ func (x *SetLogLevelRequest) String() string { func (*SetLogLevelRequest) ProtoMessage() {} func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1928,7 +2047,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead. func (*SetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28} + return file_daemon_proto_rawDescGZIP(), []int{29} } func (x *SetLogLevelRequest) GetLevel() LogLevel { @@ -1947,7 +2066,7 @@ type SetLogLevelResponse struct { func (x *SetLogLevelResponse) Reset() { *x = SetLogLevelResponse{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1960,7 +2079,7 @@ func (x *SetLogLevelResponse) String() string { func (*SetLogLevelResponse) ProtoMessage() {} func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1973,7 +2092,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead. func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{29} + return file_daemon_proto_rawDescGZIP(), []int{30} } var File_daemon_proto protoreflect.FileDescriptor @@ -1986,7 +2105,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xcf, 0x07, 0x0a, 0x0c, 0x4c, 0x6f, + 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xb0, 0x08, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, @@ -2037,186 +2156,221 @@ var file_daemon_proto_rawDesc = []byte{ 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x12, 0x2b, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, 0x48, 0x07, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x88, - 0x01, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, 0x69, - 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, 0x0a, 0x15, 0x5f, - 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, - 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, - 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, - 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, - 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, - 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x42, 0x11, 0x0a, 0x0f, 0x5f, 0x6e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x22, 0xb5, 0x01, 0x0a, 0x0d, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, - 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, - 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, - 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, - 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, 0x72, + 0x01, 0x01, 0x12, 0x4a, 0x0a, 0x10, 0x64, 0x6e, 0x73, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x6e, + 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x13, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, + 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x08, 0x52, 0x10, 0x64, 0x6e, 0x73, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x88, 0x01, 0x01, 0x42, 0x13, + 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, + 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, + 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, 0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, + 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, + 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, + 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, + 0x73, 0x69, 0x76, 0x65, 0x42, 0x11, 0x0a, 0x0f, 0x5f, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x64, 0x6e, 0x73, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x22, 0xb5, 0x01, 0x0a, + 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, + 0x0a, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, + 0x12, 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x55, 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, + 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, + 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, + 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, - 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, - 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, - 0x65, 0x74, 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, - 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, - 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, - 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, 0x70, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, - 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, - 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, 0x6c, - 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, 0x77, - 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb3, 0x01, 0x0a, - 0x11, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x55, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, 0x46, - 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, - 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, - 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, - 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, - 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, - 0x52, 0x4c, 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, - 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, - 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, - 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, - 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, - 0x72, 0x65, 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, - 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, - 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, - 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, - 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, - 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, - 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, - 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, - 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, - 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, - 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, - 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, - 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, - 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, - 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, - 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, - 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, - 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, - 0x6e, 0x63, 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, - 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, - 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, - 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, - 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, - 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, - 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, - 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x6c, 0x65, 0x74, 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, + 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, + 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, + 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, + 0x61, 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, + 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, + 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x12, 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, + 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, + 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb9, 0x03, + 0x0a, 0x11, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x55, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, + 0x46, 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, + 0x69, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, + 0x4b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, + 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, + 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, + 0x55, 0x52, 0x4c, 0x12, 0x24, 0x0a, 0x0d, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, + 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x77, 0x69, 0x72, + 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x0d, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x12, + 0x2e, 0x0a, 0x12, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x64, 0x69, 0x73, + 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x12, + 0x2a, 0x0a, 0x10, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, + 0x77, 0x65, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x73, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x72, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, + 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, + 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, + 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, + 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, + 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, + 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, + 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, + 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, + 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, + 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, + 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, + 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, + 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, + 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, + 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, + 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, + 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, + 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, + 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, + 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, + 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, + 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, + 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, + 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, + 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, + 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, + 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, + 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, + 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, + 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, + 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, + 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, + 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, + 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, + 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, + 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, + 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, + 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, + 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, + 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, + 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, + 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, - 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, - 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, - 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, - 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, - 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, - 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, - 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, - 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, - 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, - 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, - 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, - 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, - 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, - 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a, - 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, - 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, - 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, - 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, - 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14, - 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x4d, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, + 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, + 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, + 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, + 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, + 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, + 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, + 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, + 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, + 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, + 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, + 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, + 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, + 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, + 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, + 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, + 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, + 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, + 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, + 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, + 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, + 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, + 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49, + 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, + 0x44, 0x12, 0x18, 0x0a, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x40, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, + 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, + 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, + 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, + 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, + 0x02, 0x38, 0x01, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, @@ -2309,85 +2463,90 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 30) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 32) var file_daemon_proto_goTypes = []interface{}{ - (LogLevel)(0), // 0: daemon.LogLevel - (*LoginRequest)(nil), // 1: daemon.LoginRequest - (*LoginResponse)(nil), // 2: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 5: daemon.UpRequest - (*UpResponse)(nil), // 6: daemon.UpResponse - (*StatusRequest)(nil), // 7: daemon.StatusRequest - (*StatusResponse)(nil), // 8: daemon.StatusResponse - (*DownRequest)(nil), // 9: daemon.DownRequest - (*DownResponse)(nil), // 10: daemon.DownResponse - (*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse - (*PeerState)(nil), // 13: daemon.PeerState - (*LocalPeerState)(nil), // 14: daemon.LocalPeerState - (*SignalState)(nil), // 15: daemon.SignalState - (*ManagementState)(nil), // 16: daemon.ManagementState - (*RelayState)(nil), // 17: daemon.RelayState - (*NSGroupState)(nil), // 18: daemon.NSGroupState - (*FullStatus)(nil), // 19: daemon.FullStatus - (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest - (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse - (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest - (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse - (*Route)(nil), // 24: daemon.Route - (*DebugBundleRequest)(nil), // 25: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 26: daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 27: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 28: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 29: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 30: daemon.SetLogLevelResponse - (*timestamp.Timestamp)(nil), // 31: google.protobuf.Timestamp - (*duration.Duration)(nil), // 32: google.protobuf.Duration + (LogLevel)(0), // 0: daemon.LogLevel + (*LoginRequest)(nil), // 1: daemon.LoginRequest + (*LoginResponse)(nil), // 2: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 5: daemon.UpRequest + (*UpResponse)(nil), // 6: daemon.UpResponse + (*StatusRequest)(nil), // 7: daemon.StatusRequest + (*StatusResponse)(nil), // 8: daemon.StatusResponse + (*DownRequest)(nil), // 9: daemon.DownRequest + (*DownResponse)(nil), // 10: daemon.DownResponse + (*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse + (*PeerState)(nil), // 13: daemon.PeerState + (*LocalPeerState)(nil), // 14: daemon.LocalPeerState + (*SignalState)(nil), // 15: daemon.SignalState + (*ManagementState)(nil), // 16: daemon.ManagementState + (*RelayState)(nil), // 17: daemon.RelayState + (*NSGroupState)(nil), // 18: daemon.NSGroupState + (*FullStatus)(nil), // 19: daemon.FullStatus + (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest + (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse + (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest + (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse + (*IPList)(nil), // 24: daemon.IPList + (*Route)(nil), // 25: daemon.Route + (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse + nil, // 32: daemon.Route.ResolvedIPsEntry + (*durationpb.Duration)(nil), // 33: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 34: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 19, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 31, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 31, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 32, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 16, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 15, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 14, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 13, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState - 17, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState - 18, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 24, // 10: daemon.ListRoutesResponse.routes:type_name -> daemon.Route - 0, // 11: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 12: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 1, // 13: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 3, // 14: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 5, // 15: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 7, // 16: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 9, // 17: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 11, // 18: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 20, // 19: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest - 22, // 20: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest - 22, // 21: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest - 25, // 22: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 27, // 23: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 29, // 24: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 2, // 25: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 4, // 26: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 6, // 27: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 8, // 28: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 10, // 29: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 12, // 30: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 21, // 31: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse - 23, // 32: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse - 23, // 33: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse - 26, // 34: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 28, // 35: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 30, // 36: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 25, // [25:37] is the sub-list for method output_type - 13, // [13:25] is the sub-list for method input_type - 13, // [13:13] is the sub-list for extension type_name - 13, // [13:13] is the sub-list for extension extendee - 0, // [0:13] is the sub-list for field type_name + 33, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 34, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 34, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 33, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 13, // 8: daemon.FullStatus.peers:type_name -> daemon.PeerState + 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState + 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route + 32, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry + 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 24, // 15: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList + 1, // 16: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 3, // 17: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 5, // 18: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 7, // 19: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 9, // 20: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 11, // 21: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 20, // 22: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest + 22, // 23: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest + 22, // 24: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest + 26, // 25: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 28, // 26: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 30, // 27: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 2, // 28: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 4, // 29: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 6, // 30: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 8, // 31: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 10, // 32: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 12, // 33: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 21, // 34: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse + 23, // 35: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse + 23, // 36: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse + 27, // 37: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 29, // 38: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 31, // 39: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 28, // [28:40] is the sub-list for method output_type + 16, // [16:28] is the sub-list for method input_type + 16, // [16:16] is the sub-list for extension type_name + 16, // [16:16] is the sub-list for extension extendee + 0, // [0:16] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -2673,7 +2832,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*IPList); i { case 0: return &v.state case 1: @@ -2685,7 +2844,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DebugBundleRequest); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -2697,7 +2856,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DebugBundleResponse); i { + switch v := v.(*DebugBundleRequest); i { case 0: return &v.state case 1: @@ -2709,7 +2868,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetLogLevelRequest); i { + switch v := v.(*DebugBundleResponse); i { case 0: return &v.state case 1: @@ -2721,7 +2880,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetLogLevelResponse); i { + switch v := v.(*GetLogLevelRequest); i { case 0: return &v.state case 1: @@ -2733,7 +2892,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SetLogLevelRequest); i { + switch v := v.(*GetLogLevelResponse); i { case 0: return &v.state case 1: @@ -2745,6 +2904,18 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SetLogLevelRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*SetLogLevelResponse); i { case 0: return &v.state @@ -2764,7 +2935,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_daemon_proto_rawDesc, NumEnums: 1, - NumMessages: 30, + NumMessages: 32, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index e39b08bc3..267eec279 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -92,6 +92,8 @@ message LoginRequest { repeated string extraIFaceBlacklist = 17; optional bool networkMonitor = 18; + + optional google.protobuf.Duration dnsRouteInterval = 19; } message LoginResponse { @@ -145,6 +147,18 @@ message GetConfigResponse { // adminURL settings value. string adminURL = 5; + + string interfaceName = 6; + + int64 wireguardPort = 7; + + bool disableAutoConnect = 9; + + bool serverSSHAllowed = 10; + + bool rosenpassEnabled = 11; + + bool rosenpassPermissive = 12; } // PeerState contains the latest state of a peer @@ -233,10 +247,17 @@ message SelectRoutesRequest { message SelectRoutesResponse { } +message IPList { + repeated string ips = 1; +} + + message Route { string ID = 1; string network = 2; bool selected = 3; + repeated string domains = 4; + map resolvedIPs = 5; } message DebugBundleRequest { diff --git a/client/server/route.go b/client/server/route.go index 4c63cea93..d70e0dca3 100644 --- a/client/server/route.go +++ b/client/server/route.go @@ -9,17 +9,19 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) type selectRoute struct { NetID route.NetID Network netip.Prefix + Domains domain.List Selected bool } // ListRoutes returns a list of all available routes. -func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) { +func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -43,6 +45,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) ( route := &selectRoute{ NetID: id, Network: rt[0].Network, + Domains: rt[0].Domains, Selected: routeSelector.IsSelected(id), } routes = append(routes, route) @@ -63,13 +66,29 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) ( return iPrefix < jPrefix }) + resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() var pbRoutes []*proto.Route for _, route := range routes { - pbRoutes = append(pbRoutes, &proto.Route{ - ID: string(route.NetID), - Network: route.Network.String(), - Selected: route.Selected, - }) + pbRoute := &proto.Route{ + ID: string(route.NetID), + Network: route.Network.String(), + Domains: route.Domains.ToSafeStringList(), + ResolvedIPs: map[string]*proto.IPList{}, + Selected: route.Selected, + } + + for _, domain := range route.Domains { + if prefixes, exists := resolvedDomains[domain]; exists { + var ipStrings []string + for _, prefix := range prefixes { + ipStrings = append(ipStrings, prefix.Addr().String()) + } + pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ + Ips: ipStrings, + } + } + } + pbRoutes = append(pbRoutes, pbRoute) } return &proto.ListRoutesResponse{ diff --git a/client/server/server.go b/client/server/server.go index a59cffd14..2805c10f4 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -365,6 +365,12 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist } + if msg.DnsRouteInterval != nil { + duration := msg.DnsRouteInterval.AsDuration() + inputConfig.DNSRouteInterval = &duration + s.latestConfigInput.DNSRouteInterval = &duration + } + s.mutex.Unlock() if msg.OptionalPreSharedKey != nil { @@ -662,11 +668,17 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto } return &proto.GetConfigResponse{ - ManagementUrl: managementURL, - AdminURL: adminURL, - ConfigFile: s.latestConfigInput.ConfigPath, - LogFile: s.logFile, - PreSharedKey: preSharedKey, + ManagementUrl: managementURL, + ConfigFile: s.latestConfigInput.ConfigPath, + LogFile: s.logFile, + PreSharedKey: preSharedKey, + AdminURL: adminURL, + InterfaceName: s.config.WgIface, + WireguardPort: int64(s.config.WgPort), + DisableAutoConnect: s.config.DisableAutoConnect, + ServerSSHAllowed: *s.config.ServerSSHAllowed, + RosenpassEnabled: s.config.RosenpassEnabled, + RosenpassPermissive: s.config.RosenpassPermissive, }, nil } func (s *Server) onSessionExpire() { diff --git a/client/server/server_test.go b/client/server/server_test.go index 2337e972d..47738cb84 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -7,6 +7,8 @@ import ( "time" "github.com/netbirdio/management-integrations/integrations" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -39,7 +41,7 @@ var ( // we will use a management server started via to simulate the server and capture the number of retries func TestConnectWithRetryRuns(t *testing.T) { // start the signal server - _, signalAddr, err := startSignal() + _, signalAddr, err := startSignal(t) if err != nil { t.Fatalf("failed to start signal server: %v", err) } @@ -106,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir) + store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { return nil, "", err } @@ -117,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(eventStore) - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) + ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "") - mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { return nil, "", err } @@ -141,7 +143,9 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return s, lis.Addr().String(), nil } -func startSignal() (*grpc.Server, string, error) { +func startSignal(t *testing.T) (*grpc.Server, string, error) { + t.Helper() + s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) lis, err := net.Listen("tcp", "localhost:0") @@ -149,7 +153,9 @@ func startSignal() (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - proto.RegisterSignalExchangeServer(s, signalServer.NewServer()) + srv, err := signalServer.NewServer(otel.Meter("")) + require.NoError(t, err) + proto.RegisterSignalExchangeServer(s, srv) go func() { if err = s.Serve(lis); err != nil { diff --git a/client/ssh/window_freebsd.go b/client/ssh/window_freebsd.go new file mode 100644 index 000000000..ef4848341 --- /dev/null +++ b/client/ssh/window_freebsd.go @@ -0,0 +1,10 @@ +//go:build freebsd + +package ssh + +import ( + "os" +) + +func setWinSize(file *os.File, width, height int) { +} diff --git a/client/system/info.go b/client/system/info.go index e2e057206..2af2e637b 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -8,6 +8,7 @@ import ( "google.golang.org/grpc/metadata" + "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/version" ) @@ -33,6 +34,12 @@ type Environment struct { Platform string } +type File struct { + Path string + Exist bool + ProcessIsRunning bool +} + // Info is an object that contains machine information // Most of the code is taken from https://github.com/matishsiao/goInfo type Info struct { @@ -51,6 +58,7 @@ type Info struct { SystemProductName string SystemManufacturer string Environment Environment + Files []File // for posture checks } // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context @@ -132,3 +140,21 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { } return false } + +// GetInfoWithChecks retrieves and parses the system information with applied checks. +func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { + processCheckPaths := make([]string, 0) + for _, check := range checks { + processCheckPaths = append(processCheckPaths, check.GetFiles()...) + } + + files, err := checkFileAndProcess(processCheckPaths) + if err != nil { + return nil, err + } + + info := GetInfo(ctx) + info.Files = files + + return info, nil +} diff --git a/client/system/info_android.go b/client/system/info_android.go index 7f5dd371b..7718da913 100644 --- a/client/system/info_android.go +++ b/client/system/info_android.go @@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info { GoOS: runtime.GOOS, Kernel: kernel, Platform: "unknown", - OS: "android", + OS: "Android", OSVersion: osVersion(), Hostname: extractDeviceName(ctx, "android"), CPUs: runtime.NumCPU(), @@ -44,6 +44,11 @@ func GetInfo(ctx context.Context) *Info { return gio } +// checkFileAndProcess checks if the file path exists and if a process is running at that path. +func checkFileAndProcess(paths []string) ([]File, error) { + return []File{}, nil +} + func uname() []string { res := run("/system/bin/uname", "-a") return strings.Split(res, " ") @@ -72,5 +77,6 @@ func run(name string, arg ...string) string { if err != nil { log.Errorf("getInfo: %s", err) } - return out.String() + + return strings.TrimSpace(out.String()) } diff --git a/client/system/info_freebsd.go b/client/system/info_freebsd.go index b44fdee7c..454e58a0b 100644 --- a/client/system/info_freebsd.go +++ b/client/system/info_freebsd.go @@ -1,15 +1,18 @@ +//go:build freebsd + package system import ( "bytes" "context" - "fmt" "os" "os/exec" "runtime" "strings" "time" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/system/detect_cloud" "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" @@ -22,8 +25,8 @@ func GetInfo(ctx context.Context) *Info { out = _getInfo() time.Sleep(500 * time.Millisecond) } - osStr := strings.Replace(out, "\n", "", -1) - osStr = strings.Replace(osStr, "\r\n", "", -1) + osStr := strings.ReplaceAll(out, "\n", "") + osStr = strings.ReplaceAll(osStr, "\r\n", "") osInfo := strings.Split(osStr, " ") env := Environment{ @@ -31,14 +34,23 @@ func GetInfo(ctx context.Context) *Info { Platform: detect_platform.Detect(ctx), } - gio := &Info{Kernel: osInfo[0], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: osInfo[1], Environment: env} + osName, osVersion := readOsReleaseFile() systemHostname, _ := os.Hostname() - gio.Hostname = extractDeviceName(ctx, systemHostname) - gio.WiretrusteeVersion = version.NetbirdVersion() - gio.UIVersion = extractUserAgent(ctx) - return gio + return &Info{ + GoOS: runtime.GOOS, + Kernel: osInfo[0], + Platform: runtime.GOARCH, + OS: osName, + OSVersion: osVersion, + Hostname: extractDeviceName(ctx, systemHostname), + CPUs: runtime.NumCPU(), + WiretrusteeVersion: version.NetbirdVersion(), + UIVersion: extractUserAgent(ctx), + KernelVersion: osInfo[1], + Environment: env, + } } func _getInfo() string { @@ -50,7 +62,8 @@ func _getInfo() string { cmd.Stderr = &stderr err := cmd.Run() if err != nil { - fmt.Println("getInfo:", err) + log.Warnf("getInfo: %s", err) } + return out.String() } diff --git a/client/system/info_ios.go b/client/system/info_ios.go index e1c291ef5..3dbf50e1e 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -25,6 +25,11 @@ func GetInfo(ctx context.Context) *Info { return gio } +// checkFileAndProcess checks if the file path exists and if a process is running at that path. +func checkFileAndProcess(paths []string) ([]File, error) { + return []File{}, nil +} + // extractOsVersion extracts operating system version from context or returns the default func extractOsVersion(ctx context.Context, defaultName string) string { v, ok := ctx.Value(OsVersionCtxKey).(string) diff --git a/client/system/info_linux.go b/client/system/info_linux.go index 652bc1115..d85a6faec 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -28,28 +28,11 @@ func GetInfo(ctx context.Context) *Info { time.Sleep(500 * time.Millisecond) } - releaseInfo := _getReleaseInfo() - for strings.Contains(info, "broken pipe") { - releaseInfo = _getReleaseInfo() - time.Sleep(500 * time.Millisecond) - } - - osRelease := strings.Split(releaseInfo, "\n") - var osName string - var osVer string - for _, s := range osRelease { - if strings.HasPrefix(s, "NAME=") { - osName = strings.Split(s, "=")[1] - osName = strings.ReplaceAll(osName, "\"", "") - } else if strings.HasPrefix(s, "VERSION_ID=") { - osVer = strings.Split(s, "=")[1] - osVer = strings.ReplaceAll(osVer, "\"", "") - } - } - osStr := strings.ReplaceAll(info, "\n", "") osStr = strings.ReplaceAll(osStr, "\r\n", "") osInfo := strings.Split(osStr, " ") + + osName, osVersion := readOsReleaseFile() if osName == "" { osName = osInfo[3] } @@ -72,7 +55,7 @@ func GetInfo(ctx context.Context) *Info { Kernel: osInfo[0], Platform: osInfo[2], OS: osName, - OSVersion: osVer, + OSVersion: osVersion, Hostname: extractDeviceName(ctx, systemHostname), GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), @@ -103,22 +86,12 @@ func _getInfo() string { return out.String() } -func _getReleaseInfo() string { - cmd := exec.Command("cat", "/etc/os-release") - cmd.Stdin = strings.NewReader("some") - var out bytes.Buffer - var stderr bytes.Buffer - cmd.Stdout = &out - cmd.Stderr = &stderr - err := cmd.Run() - if err != nil { - log.Warnf("geucwReleaseInfo: %s", err) - } - return out.String() -} - func sysInfo() (serialNumber string, productName string, manufacturer string) { var si sysinfo.SysInfo si.GetSysInfo() - return si.Chassis.Serial, si.Product.Name, si.Product.Vendor + serial := si.Chassis.Serial + if (serial == "Default string" || serial == "") && si.Product.Serial != "" { + serial = si.Product.Serial + } + return serial, si.Product.Name, si.Product.Vendor } diff --git a/client/system/osrelease_unix.go b/client/system/osrelease_unix.go new file mode 100644 index 000000000..851633248 --- /dev/null +++ b/client/system/osrelease_unix.go @@ -0,0 +1,38 @@ +//go:build (linux && !android) || freebsd + +package system + +import ( + "bufio" + "os" + "strings" + + log "github.com/sirupsen/logrus" +) + +func readOsReleaseFile() (osName string, osVer string) { + file, err := os.Open("/etc/os-release") + if err != nil { + log.Warnf("failed to open file /etc/os-release: %s", err) + return "", "" + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "NAME=") { + osName = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "") + continue + } + if strings.HasPrefix(line, "VERSION_ID=") { + osVer = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "") + continue + } + + if osName != "" && osVer != "" { + break + } + } + return +} diff --git a/client/system/process.go b/client/system/process.go new file mode 100644 index 000000000..2e43fcfe0 --- /dev/null +++ b/client/system/process.go @@ -0,0 +1,58 @@ +//go:build windows || (linux && !android) || (darwin && !ios) || freebsd + +package system + +import ( + "os" + "slices" + + "github.com/shirou/gopsutil/v3/process" +) + +// getRunningProcesses returns a list of running process paths. +func getRunningProcesses() ([]string, error) { + processes, err := process.Processes() + if err != nil { + return nil, err + } + + processMap := make(map[string]bool) + for _, p := range processes { + path, _ := p.Exe() + if path != "" { + processMap[path] = true + } + } + + uniqueProcesses := make([]string, 0, len(processMap)) + for p := range processMap { + uniqueProcesses = append(uniqueProcesses, p) + } + + return uniqueProcesses, nil +} + +// checkFileAndProcess checks if the file path exists and if a process is running at that path. +func checkFileAndProcess(paths []string) ([]File, error) { + files := make([]File, len(paths)) + if len(paths) == 0 { + return files, nil + } + + runningProcesses, err := getRunningProcesses() + if err != nil { + return nil, err + } + + for i, path := range paths { + file := File{Path: path} + + _, err := os.Stat(path) + file.Exist = !os.IsNotExist(err) + + file.ProcessIsRunning = slices.Contains(runningProcesses, path) + files[i] = file + } + + return files, nil +} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 7b1e0320a..cadd14f18 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -1,10 +1,11 @@ -//go:build !(linux && 386) +//go:build !(linux && 386) && !freebsd package main import ( "context" _ "embed" + "errors" "flag" "fmt" "os" @@ -79,6 +80,7 @@ func main() { log.Errorf("check PID file: %v", err) return } + client.setDefaultFonts() systray.Run(client.onTrayReady, client.onTrayExit) } } @@ -125,44 +127,55 @@ type serviceClient struct { icUpdateCloud []byte // systray menu items - mStatus *systray.MenuItem - mUp *systray.MenuItem - mDown *systray.MenuItem - mAdminPanel *systray.MenuItem - mSettings *systray.MenuItem - mAbout *systray.MenuItem - mVersionUI *systray.MenuItem - mVersionDaemon *systray.MenuItem - mUpdate *systray.MenuItem - mQuit *systray.MenuItem - mRoutes *systray.MenuItem + mStatus *systray.MenuItem + mUp *systray.MenuItem + mDown *systray.MenuItem + mAdminPanel *systray.MenuItem + mSettings *systray.MenuItem + mAbout *systray.MenuItem + mVersionUI *systray.MenuItem + mVersionDaemon *systray.MenuItem + mUpdate *systray.MenuItem + mQuit *systray.MenuItem + mRoutes *systray.MenuItem + mAllowSSH *systray.MenuItem + mAutoConnect *systray.MenuItem + mEnableRosenpass *systray.MenuItem + mAdvancedSettings *systray.MenuItem // application with main windows. - app fyne.App - wSettings fyne.Window - showSettings bool - sendNotification bool + app fyne.App + wSettings fyne.Window + showAdvancedSettings bool + sendNotification bool // input elements for settings form - iMngURL *widget.Entry - iAdminURL *widget.Entry - iConfigFile *widget.Entry - iLogFile *widget.Entry - iPreSharedKey *widget.Entry + iMngURL *widget.Entry + iAdminURL *widget.Entry + iConfigFile *widget.Entry + iLogFile *widget.Entry + iPreSharedKey *widget.Entry + iInterfaceName *widget.Entry + iInterfacePort *widget.Entry + + // switch elements for settings form + sRosenpassPermissive *widget.Check // observable settings over corresponding iMngURL and iPreSharedKey values. - managementURL string - preSharedKey string - adminURL string + managementURL string + preSharedKey string + adminURL string + RosenpassPermissive bool + interfaceName string + interfacePort int connected bool update *version.Update daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool - - showRoutes bool - wRoutes fyne.Window + showRoutes bool + wRoutes fyne.Window } // newServiceClient instance constructor @@ -175,9 +188,9 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo app: a, sendNotification: false, - showSettings: showSettings, - showRoutes: showRoutes, - update: version.NewUpdate(), + showAdvancedSettings: showSettings, + showRoutes: showRoutes, + update: version.NewUpdate(), } if runtime.GOOS == "windows" { @@ -215,8 +228,13 @@ func (s *serviceClient) showSettingsUI() { s.iLogFile = widget.NewEntry() s.iLogFile.Disable() s.iPreSharedKey = widget.NewPasswordEntry() + s.iInterfaceName = widget.NewEntry() + s.iInterfacePort = widget.NewEntry() + s.sRosenpassPermissive = widget.NewCheck("Enable Rosenpass permissive mode", nil) + s.wSettings.SetContent(s.getSettingsForm()) - s.wSettings.Resize(fyne.NewSize(600, 100)) + s.wSettings.Resize(fyne.NewSize(600, 400)) + s.wSettings.SetFixedSize(true) s.getSrvConfig() @@ -239,6 +257,9 @@ func showErrorMSG(msg string) { func (s *serviceClient) getSettingsForm() *widget.Form { return &widget.Form{ Items: []*widget.FormItem{ + {Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive}, + {Text: "Interface Name", Widget: s.iInterfaceName}, + {Text: "Interface Port", Widget: s.iInterfacePort}, {Text: "Management URL", Widget: s.iMngURL}, {Text: "Admin URL", Widget: s.iAdminURL}, {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, @@ -255,45 +276,45 @@ func (s *serviceClient) getSettingsForm() *widget.Form { } } + port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) + if err != nil { + dialog.ShowError(errors.New("Invalid interface port"), s.wSettings) + return + } + + iAdminURL := strings.TrimSpace(s.iAdminURL.Text) + iMngURL := strings.TrimSpace(s.iMngURL.Text) + defer s.wSettings.Close() - // if management URL or Pre-shared key changed, we try to re-login with new settings. - if s.managementURL != s.iMngURL.Text || s.preSharedKey != s.iPreSharedKey.Text || - s.adminURL != s.iAdminURL.Text { - s.managementURL = s.iMngURL.Text + // If the management URL, pre-shared key, admin URL, Rosenpass permissive mode, + // interface name, or interface port have changed, we attempt to re-login with the new settings. + if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text || + s.adminURL != iAdminURL || s.RosenpassPermissive != s.sRosenpassPermissive.Checked || + s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) { + + s.managementURL = iMngURL s.preSharedKey = s.iPreSharedKey.Text - s.adminURL = s.iAdminURL.Text - - client, err := s.getSrvClient(failFastTimeout) - if err != nil { - log.Errorf("get daemon client: %v", err) - return - } + s.adminURL = iAdminURL loginRequest := proto.LoginRequest{ - ManagementUrl: s.iMngURL.Text, - AdminURL: s.iAdminURL.Text, + ManagementUrl: iMngURL, + AdminURL: iAdminURL, IsLinuxDesktopClient: runtime.GOOS == "linux", + RosenpassPermissive: &s.sRosenpassPermissive.Checked, + InterfaceName: &s.iInterfaceName.Text, + WireguardPort: &port, } if s.iPreSharedKey.Text != "**********" { loginRequest.OptionalPreSharedKey = &s.iPreSharedKey.Text } - _, err = client.Login(s.ctx, &loginRequest) - if err != nil { - log.Errorf("login to management URL: %v", err) + if err := s.restartClient(&loginRequest); err != nil { + log.Errorf("restarting client connection: %v", err) return } - - _, err = client.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("login to management URL: %v", err) - return - } - } - s.wSettings.Close() }, OnCancel: func() { s.wSettings.Close() @@ -499,7 +520,14 @@ func (s *serviceClient) onTrayReady() { s.mDown.Disable() s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Netbird Admin Panel") systray.AddSeparator() + s.mSettings = systray.AddMenuItem("Settings", "Settings of the application") + s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", "Allow SSH connections", false) + s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", "Connect automatically when the service starts", false) + s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", "Enable post-quantum security via Rosenpass", false) + s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application") + s.loadSettings() + s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window") s.mRoutes.Disable() systray.AddSeparator() @@ -539,7 +567,7 @@ func (s *serviceClient) onTrayReady() { case <-s.mAdminPanel.ClickedCh: err = open.Run(s.adminURL) case <-s.mUp.ClickedCh: - s.mUp.Disabled() + s.mUp.Disable() go func() { defer s.mUp.Enable() err := s.menuUpClick() @@ -558,10 +586,40 @@ func (s *serviceClient) onTrayReady() { return } }() - case <-s.mSettings.ClickedCh: - s.mSettings.Disable() + case <-s.mAllowSSH.ClickedCh: + if s.mAllowSSH.Checked() { + s.mAllowSSH.Uncheck() + } else { + s.mAllowSSH.Check() + } + if err := s.updateConfig(); err != nil { + log.Errorf("failed to update config: %v", err) + return + } + case <-s.mAutoConnect.ClickedCh: + if s.mAutoConnect.Checked() { + s.mAutoConnect.Uncheck() + } else { + s.mAutoConnect.Check() + } + if err := s.updateConfig(); err != nil { + log.Errorf("failed to update config: %v", err) + return + } + case <-s.mEnableRosenpass.ClickedCh: + if s.mEnableRosenpass.Checked() { + s.mEnableRosenpass.Uncheck() + } else { + s.mEnableRosenpass.Check() + } + if err := s.updateConfig(); err != nil { + log.Errorf("failed to update config: %v", err) + return + } + case <-s.mAdvancedSettings.ClickedCh: + s.mAdvancedSettings.Disable() go func() { - defer s.mSettings.Enable() + defer s.mAdvancedSettings.Enable() defer s.getSrvConfig() s.runSelfCommand("settings", "true") }() @@ -663,13 +721,23 @@ func (s *serviceClient) getSrvConfig() { s.adminURL = cfg.AdminURL } s.preSharedKey = cfg.PreSharedKey + s.RosenpassPermissive = cfg.RosenpassPermissive + s.interfaceName = cfg.InterfaceName + s.interfacePort = int(cfg.WireguardPort) - if s.showSettings { + if s.showAdvancedSettings { s.iMngURL.SetText(s.managementURL) s.iAdminURL.SetText(s.adminURL) s.iConfigFile.SetText(cfg.ConfigFile) s.iLogFile.SetText(cfg.LogFile) s.iPreSharedKey.SetText(cfg.PreSharedKey) + s.iInterfaceName.SetText(cfg.InterfaceName) + s.iInterfacePort.SetText(strconv.Itoa(int(cfg.WireguardPort))) + s.sRosenpassPermissive.SetChecked(cfg.RosenpassPermissive) + if !cfg.RosenpassEnabled { + s.sRosenpassPermissive.Disable() + } + } } @@ -704,6 +772,81 @@ func (s *serviceClient) onSessionExpire() { } } +// loadSettings loads the settings from the config file and updates the UI elements accordingly. +func (s *serviceClient) loadSettings() { + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + log.Errorf("get client: %v", err) + return + } + + cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{}) + if err != nil { + log.Errorf("get config settings from server: %v", err) + return + } + + if cfg.ServerSSHAllowed { + s.mAllowSSH.Check() + } else { + s.mAllowSSH.Uncheck() + } + + if cfg.DisableAutoConnect { + s.mAutoConnect.Uncheck() + } else { + s.mAutoConnect.Check() + } + + if cfg.RosenpassEnabled { + s.mEnableRosenpass.Check() + } else { + s.mEnableRosenpass.Uncheck() + } +} + +// updateConfig updates the configuration parameters +// based on the values selected in the settings window. +func (s *serviceClient) updateConfig() error { + disableAutoStart := !s.mAutoConnect.Checked() + sshAllowed := s.mAllowSSH.Checked() + rosenpassEnabled := s.mEnableRosenpass.Checked() + + loginRequest := proto.LoginRequest{ + IsLinuxDesktopClient: runtime.GOOS == "linux", + ServerSSHAllowed: &sshAllowed, + RosenpassEnabled: &rosenpassEnabled, + DisableAutoConnect: &disableAutoStart, + } + + if err := s.restartClient(&loginRequest); err != nil { + log.Errorf("restarting client connection: %v", err) + return err + } + + return nil +} + +// restartClient restarts the client connection. +func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error { + client, err := s.getSrvClient(failFastTimeout) + if err != nil { + return err + } + + _, err = client.Login(s.ctx, loginRequest) + if err != nil { + return err + } + + _, err = client.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + return err + } + + return nil +} + func openURL(url string) error { var err error switch runtime.GOOS { @@ -734,3 +877,88 @@ func checkPIDFile() error { return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec } + +func (s *serviceClient) setDefaultFonts() { + var ( + defaultFontPath string + ) + + //TODO: Linux Multiple Language Support + switch runtime.GOOS { + case "darwin": + defaultFontPath = "/Library/Fonts/Arial Unicode.ttf" + case "windows": + fontPath := s.getWindowsFontFilePath() + defaultFontPath = fontPath + } + + _, err := os.Stat(defaultFontPath) + + if err == nil { + os.Setenv("FYNE_FONT", defaultFontPath) + } +} + +func (s *serviceClient) getWindowsFontFilePath() (fontPath string) { + /* + https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts + https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list + */ + + var ( + fontFolder string = "C:/Windows/Fonts" + fontMapping = map[string]string{ + "default": "Segoeui.ttf", + "zh-CN": "Msyh.ttc", + "am-ET": "Ebrima.ttf", + "nirmala": "Nirmala.ttf", + "chr-CHER-US": "Gadugi.ttf", + "zh-HK": "Msjh.ttc", + "zh-TW": "Msjh.ttc", + "ja-JP": "Yugothm.ttc", + "km-KH": "Leelawui.ttf", + "ko-KR": "Malgun.ttf", + "th-TH": "Leelawui.ttf", + "ti-ET": "Ebrima.ttf", + } + nirMalaLang = []string{ + "as-IN", + "bn-BD", + "bn-IN", + "gu-IN", + "hi-IN", + "kn-IN", + "kok-IN", + "ml-IN", + "mr-IN", + "ne-NP", + "or-IN", + "pa-IN", + "si-LK", + "ta-IN", + "te-IN", + } + ) + cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name") + output, err := cmd.Output() + if err != nil { + log.Errorf("Failed to get Windows default language setting: %v", err) + fontPath = path.Join(fontFolder, fontMapping["default"]) + return + } + defaultLanguage := strings.TrimSpace(string(output)) + + for _, lang := range nirMalaLang { + if defaultLanguage == lang { + fontPath = path.Join(fontFolder, fontMapping["nirmala"]) + return + } + } + + if font, ok := fontMapping[defaultLanguage]; ok { + fontPath = path.Join(fontFolder, font) + } else { + fontPath = path.Join(fontFolder, fontMapping["default"]) + } + return +} diff --git a/client/ui/route.go b/client/ui/route.go index 0ac58e5d5..5b6b8fee0 100644 --- a/client/ui/route.go +++ b/client/ui/route.go @@ -1,9 +1,10 @@ -//go:build !(linux && 386) +//go:build !(linux && 386) && !freebsd package main import ( "fmt" + "sort" "strings" "time" @@ -17,28 +18,57 @@ import ( "github.com/netbirdio/netbird/client/proto" ) +const ( + allRoutesText = "All routes" + overlappingRoutesText = "Overlapping routes" + exitNodeRoutesText = "Exit-node routes" + allRoutes filter = "all" + overlappingRoutes filter = "overlapping" + exitNodeRoutes filter = "exit-node" + getClientFMT = "get client: %v" +) + +type filter string + func (s *serviceClient) showRoutesUI() { s.wRoutes = s.app.NewWindow("NetBird Routes") - grid := container.New(layout.NewGridLayout(2)) - go s.updateRoutes(grid) + allGrid := container.New(layout.NewGridLayout(3)) + go s.updateRoutes(allGrid, allRoutes) + overlappingGrid := container.New(layout.NewGridLayout(3)) + exitNodeGrid := container.New(layout.NewGridLayout(3)) routeCheckContainer := container.NewVBox() - routeCheckContainer.Add(grid) + tabs := container.NewAppTabs( + container.NewTabItem(allRoutesText, allGrid), + container.NewTabItem(overlappingRoutesText, overlappingGrid), + container.NewTabItem(exitNodeRoutesText, exitNodeGrid), + ) + tabs.OnSelected = func(item *container.TabItem) { + s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + } + tabs.OnUnselected = func(item *container.TabItem) { + grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + grid.Objects = nil + } + + routeCheckContainer.Add(tabs) scrollContainer := container.NewVScroll(routeCheckContainer) scrollContainer.SetMinSize(fyne.NewSize(200, 300)) buttonBox := container.NewHBox( layout.NewSpacer(), widget.NewButton("Refresh", func() { - s.updateRoutes(grid) + s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Select all", func() { - s.selectAllRoutes() - s.updateRoutes(grid) + _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.selectAllFilteredRoutes(f) + s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Deselect All", func() { - s.deselectAllRoutes() - s.updateRoutes(grid) + _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.deselectAllFilteredRoutes(f) + s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), layout.NewSpacer(), ) @@ -48,27 +78,31 @@ func (s *serviceClient) showRoutesUI() { s.wRoutes.SetContent(content) s.wRoutes.Show() - s.startAutoRefresh(5*time.Second, grid) + s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid) } -func (s *serviceClient) updateRoutes(grid *fyne.Container) { - routes, err := s.fetchRoutes() - if err != nil { - log.Errorf("get client: %v", err) - s.showError(fmt.Errorf("get client: %v", err)) - return - } - +func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { grid.Objects = nil + grid.Refresh() idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) - networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) + networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) + resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) grid.Add(idHeader) grid.Add(networkHeader) - for _, route := range routes { + grid.Add(resolvedIPsHeader) + + filteredRoutes, err := s.getFilteredRoutes(f) + if err != nil { + return + } + + sortRoutesByIDs(filteredRoutes) + + for _, route := range filteredRoutes { r := route - checkBox := widget.NewCheck(r.ID, func(checked bool) { + checkBox := widget.NewCheck(r.GetID(), func(checked bool) { s.selectRoute(r.ID, checked) }) checkBox.Checked = route.Selected @@ -76,16 +110,106 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) { checkBox.Refresh() grid.Add(checkBox) - grid.Add(widget.NewLabel(r.Network)) + network := r.GetNetwork() + domains := r.GetDomains() + + if len(domains) == 0 { + grid.Add(widget.NewLabel(network)) + grid.Add(widget.NewLabel("")) + continue + } + + // our selectors are only for display + noopFunc := func(_ string) { + // do nothing + } + + domainsSelector := widget.NewSelect(domains, noopFunc) + domainsSelector.Selected = domains[0] + grid.Add(domainsSelector) + + var resolvedIPsList []string + for _, domain := range domains { + if ipList, exists := r.GetResolvedIPs()[domain]; exists { + resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", "))) + } + } + + if len(resolvedIPsList) == 0 { + grid.Add(widget.NewLabel("")) + continue + } + + // TODO: limit width within the selector display + resolvedIPsSelector := widget.NewSelect(resolvedIPsList, noopFunc) + resolvedIPsSelector.Selected = resolvedIPsList[0] + resolvedIPsSelector.Resize(fyne.NewSize(100, 100)) + grid.Add(resolvedIPsSelector) } s.wRoutes.Content().Refresh() + grid.Refresh() +} + +func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) { + routes, err := s.fetchRoutes() + if err != nil { + log.Errorf(getClientFMT, err) + s.showError(fmt.Errorf(getClientFMT, err)) + return nil, err + } + switch f { + case overlappingRoutes: + return getOverlappingRoutes(routes), nil + case exitNodeRoutes: + return getExitNodeRoutes(routes), nil + default: + } + return routes, nil +} + +func getOverlappingRoutes(routes []*proto.Route) []*proto.Route { + var filteredRoutes []*proto.Route + existingRange := make(map[string][]*proto.Route) + for _, route := range routes { + if len(route.Domains) > 0 { + continue + } + if r, exists := existingRange[route.GetNetwork()]; exists { + r = append(r, route) + existingRange[route.GetNetwork()] = r + } else { + existingRange[route.GetNetwork()] = []*proto.Route{route} + } + } + for _, r := range existingRange { + if len(r) > 1 { + filteredRoutes = append(filteredRoutes, r...) + } + } + return filteredRoutes +} + +func getExitNodeRoutes(routes []*proto.Route) []*proto.Route { + var filteredRoutes []*proto.Route + for _, route := range routes { + if route.Network == "0.0.0.0/0" { + filteredRoutes = append(filteredRoutes, route) + } + } + return filteredRoutes +} + +func sortRoutesByIDs(routes []*proto.Route) { + sort.Slice(routes, func(i, j int) bool { + return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID()) + }) } func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - return nil, fmt.Errorf("get client: %v", err) + return nil, fmt.Errorf(getClientFMT, err) } resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{}) @@ -99,8 +223,8 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { func (s *serviceClient) selectRoute(id string, checked bool) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - s.showError(fmt.Errorf("get client: %v", err)) + log.Errorf(getClientFMT, err) + s.showError(fmt.Errorf(getClientFMT, err)) return } @@ -126,16 +250,14 @@ func (s *serviceClient) selectRoute(id string, checked bool) { } } -func (s *serviceClient) selectAllRoutes() { +func (s *serviceClient) selectAllFilteredRoutes(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) + log.Errorf(getClientFMT, err) return } - req := &proto.SelectRoutesRequest{ - All: true, - } + req := s.getRoutesRequest(f, true) if _, err := conn.SelectRoutes(s.ctx, req); err != nil { log.Errorf("failed to select all routes: %v", err) s.showError(fmt.Errorf("failed to select all routes: %v", err)) @@ -145,16 +267,14 @@ func (s *serviceClient) selectAllRoutes() { log.Debug("All routes selected") } -func (s *serviceClient) deselectAllRoutes() { +func (s *serviceClient) deselectAllFilteredRoutes(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) + log.Errorf(getClientFMT, err) return } - req := &proto.SelectRoutesRequest{ - All: true, - } + req := s.getRoutesRequest(f, false) if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { log.Errorf("failed to deselect all routes: %v", err) s.showError(fmt.Errorf("failed to deselect all routes: %v", err)) @@ -164,17 +284,34 @@ func (s *serviceClient) deselectAllRoutes() { log.Debug("All routes deselected") } +func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest { + req := &proto.SelectRoutesRequest{} + if f == allRoutes { + req.All = true + } else { + routes, err := s.getFilteredRoutes(f) + if err != nil { + return nil + } + for _, route := range routes { + req.RouteIDs = append(req.RouteIDs, route.GetID()) + } + req.Append = appendRoute + } + return req +} + func (s *serviceClient) showError(err error) { wrappedMessage := wrapText(err.Error(), 50) dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes) } -func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) { +func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { ticker := time.NewTicker(interval) go func() { for range ticker.C { - s.updateRoutes(grid) + s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid) } }() @@ -183,6 +320,23 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Cont }) } +func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { + grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid) + s.wRoutes.Content().Refresh() + s.updateRoutes(grid, f) +} + +func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) { + switch tabs.Selected().Text { + case overlappingRoutesText: + return overlappingGrid, overlappingRoutes + case exitNodeRoutesText: + return exitNodesGrid, exitNodeRoutes + default: + return allGrid, allRoutes + } +} + // wrapText inserts newlines into the text to ensure that each line is // no longer than 'lineLength' runes. func wrapText(text string, lineLength int) string { diff --git a/formatter/hook.go b/formatter/hook.go index c3aa77fb3..12f27e67d 100644 --- a/formatter/hook.go +++ b/formatter/hook.go @@ -7,6 +7,18 @@ import ( "strings" "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/context" +) + +type ExecutionContext string + +const ( + ExecutionContextKey = "executionContext" + + HTTPSource ExecutionContext = "HTTP" + GRPCSource ExecutionContext = "GRPC" + SystemSource ExecutionContext = "SYSTEM" ) // ContextHook is a custom hook for add the source information for the entry @@ -30,6 +42,27 @@ func (hook ContextHook) Levels() []logrus.Level { func (hook ContextHook) Fire(entry *logrus.Entry) error { src := hook.parseSrc(entry.Caller.File) entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line) + + if entry.Context == nil { + return nil + } + + source, ok := entry.Context.Value(ExecutionContextKey).(ExecutionContext) + if !ok { + return nil + } + + entry.Data["context"] = source + + switch source { + case HTTPSource: + addHTTPFields(entry) + case GRPCSource: + addGRPCFields(entry) + case SystemSource: + addSystemFields(entry) + } + return nil } @@ -59,3 +92,42 @@ func (hook ContextHook) parseSrc(filePath string) string { file := path.Base(filePath) return fmt.Sprintf("%s/%s", pkg, file) } + +func addHTTPFields(entry *logrus.Entry) { + if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { + entry.Data[context.RequestIDKey] = ctxReqID + } + if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { + entry.Data[context.AccountIDKey] = ctxAccountID + } + if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok { + entry.Data[context.UserIDKey] = ctxInitiatorID + } +} + +func addGRPCFields(entry *logrus.Entry) { + if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { + entry.Data[context.RequestIDKey] = ctxReqID + } + if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { + entry.Data[context.AccountIDKey] = ctxAccountID + } + if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok { + entry.Data[context.PeerIDKey] = ctxDeviceID + } +} + +func addSystemFields(entry *logrus.Entry) { + if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { + entry.Data[context.RequestIDKey] = ctxReqID + } + if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok { + entry.Data[context.UserIDKey] = ctxInitiatorID + } + if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { + entry.Data[context.AccountIDKey] = ctxAccountID + } + if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok { + entry.Data[context.PeerIDKey] = ctxDeviceID + } +} diff --git a/formatter/set.go b/formatter/set.go index cceeef860..f9ccef601 100644 --- a/formatter/set.go +++ b/formatter/set.go @@ -1,6 +1,8 @@ package formatter -import "github.com/sirupsen/logrus" +import ( + "github.com/sirupsen/logrus" +) // SetTextFormatter set the text formatter for given logger. func SetTextFormatter(logger *logrus.Logger) { @@ -9,6 +11,13 @@ func SetTextFormatter(logger *logrus.Logger) { logger.AddHook(NewContextHook()) } +// SetJSONFormatter set the JSON formatter for given logger. +func SetJSONFormatter(logger *logrus.Logger) { + logger.Formatter = &logrus.JSONFormatter{} + logger.ReportCaller = true + logger.AddHook(NewContextHook()) +} + // SetLogcatFormatter set the logcat formatter for given logger. func SetLogcatFormatter(logger *logrus.Logger) { logger.Formatter = NewLogcatFormatter() diff --git a/go.mod b/go.mod index 327476c4f..94cf5000d 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,6 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/gopacket v1.1.19 - github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 @@ -58,7 +57,7 @@ require ( github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd + github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible @@ -70,6 +69,7 @@ require ( github.com/prometheus/client_golang v1.19.1 github.com/quic-go/quic-go v0.45.0 github.com/rs/xid v1.3.0 + github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.31.0 @@ -77,6 +77,7 @@ require ( github.com/things-go/go-socks5 v0.0.4 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.0.2 + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 go.opentelemetry.io/otel v1.26.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.26.0 @@ -180,7 +181,6 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.53.0 // indirect github.com/prometheus/procfs v0.15.0 // indirect - github.com/shirou/gopsutil/v3 v3.24.4 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect @@ -194,11 +194,11 @@ require ( go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect go.uber.org/mock v0.4.0 // indirect - golang.org/x/image v0.10.0 // indirect + golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.15.0 // indirect + golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.21.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect @@ -218,3 +218,5 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e + +replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 diff --git a/go.sum b/go.sum index 84eec1545..b0d239035 100644 --- a/go.sum +++ b/go.sum @@ -214,8 +214,6 @@ github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= -github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs= -github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= @@ -296,8 +294,6 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU= -github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI= @@ -341,10 +337,12 @@ github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8m github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= +github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= +github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd h1:IzGGIJMpz07aPs3R6/4sxZv63JoCMddftLpVodUK+Ec= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= @@ -513,6 +511,8 @@ github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfx github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= @@ -553,8 +553,8 @@ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.10.0 h1:gXjUUtwtx5yOE0VKWq1CH4IJAClq4UGgUA3i+rpON9M= -golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0= +golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= +golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -578,7 +578,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -669,11 +668,10 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -692,8 +690,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= -golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/iface/address.go b/iface/address.go index 2920d009f..5ff4fbc06 100644 --- a/iface/address.go +++ b/iface/address.go @@ -23,24 +23,6 @@ func parseWGAddress(address string) (WGAddress, error) { }, nil } -// Masked returns the WGAddress with the IP address part masked according to its network mask. -func (addr WGAddress) Masked() WGAddress { - ip := addr.IP.To4() - if ip == nil { - ip = addr.IP.To16() - } - - maskedIP := make(net.IP, len(ip)) - for i := range ip { - maskedIP[i] = ip[i] & addr.Network.Mask[i] - } - - return WGAddress{ - IP: maskedIP, - Network: addr.Network, - } -} - func (addr WGAddress) String() string { maskSize, _ := addr.Network.Mask.Size() return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) diff --git a/iface/bind/bind.go b/iface/bind/bind.go index 00af25f67..ba6153cb7 100644 --- a/iface/bind/bind.go +++ b/iface/bind/bind.go @@ -28,11 +28,14 @@ type ICEBind struct { transportNet transport.Net udpMux *UniversalUDPMuxDefault + + filterFn FilterFn } -func NewICEBind(transportNet transport.Net) *ICEBind { +func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { ib := &ICEBind{ transportNet: transportNet, + filterFn: filterFn, } rc := receiverCreator{ @@ -59,8 +62,9 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC s.udpMux = NewUniversalUDPMuxDefault( UniversalUDPMuxParams{ - UDPConn: conn, - Net: s.transportNet, + UDPConn: conn, + Net: s.transportNet, + FilterFn: s.filterFn, }, ) return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { diff --git a/iface/bind/udp_mux_universal.go b/iface/bind/udp_mux_universal.go index 7121f1ff4..ebbefe035 100644 --- a/iface/bind/udp_mux_universal.go +++ b/iface/bind/udp_mux_universal.go @@ -8,6 +8,8 @@ import ( "context" "fmt" "net" + "net/netip" + "sync" "time" log "github.com/sirupsen/logrus" @@ -17,6 +19,10 @@ import ( "github.com/pion/transport/v3" ) +// FilterFn is a function that filters out candidates based on the address. +// If it returns true, the address is to be filtered. It also returns the prefix of matching route. +type FilterFn func(address netip.Addr) (bool, netip.Prefix, error) + // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // It then passes packets to the UDPMux that does the actual connection muxing. type UniversalUDPMuxDefault struct { @@ -34,6 +40,7 @@ type UniversalUDPMuxParams struct { UDPConn net.PacketConn XORMappedAddrCacheTTL time.Duration Net transport.Net + FilterFn FilterFn } // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux @@ -56,6 +63,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef PacketConn: params.UDPConn, mux: m, logger: params.Logger, + filterFn: params.FilterFn, } // embed UDPMux @@ -105,8 +113,68 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) { // udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets type udpConn struct { net.PacketConn - mux *UniversalUDPMuxDefault - logger logging.LeveledLogger + mux *UniversalUDPMuxDefault + logger logging.LeveledLogger + filterFn FilterFn + // TODO: reset cache on route changes + addrCache sync.Map +} + +func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { + if u.filterFn == nil { + return u.PacketConn.WriteTo(b, addr) + } + + if isRouted, found := u.addrCache.Load(addr.String()); found { + return u.handleCachedAddress(isRouted.(bool), b, addr) + } + + return u.handleUncachedAddress(b, addr) +} + +func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { + if isRouted { + return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr) + } + return u.PacketConn.WriteTo(b, addr) +} + +func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { + if err := u.performFilterCheck(addr); err != nil { + return 0, err + } + return u.PacketConn.WriteTo(b, addr) +} + +func (u *udpConn) performFilterCheck(addr net.Addr) error { + host, err := getHostFromAddr(addr) + if err != nil { + log.Errorf("Failed to get host from address %s: %v", addr, err) + return nil + } + + a, err := netip.ParseAddr(host) + if err != nil { + log.Errorf("Failed to parse address %s: %v", addr, err) + return nil + } + + if isRouted, prefix, err := u.filterFn(a); err != nil { + log.Errorf("Failed to check if address %s is routed: %v", addr, err) + } else { + u.addrCache.Store(addr.String(), isRouted) + if isRouted { + // Extra log, as the error only shows up with ICE logging enabled + log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix) + return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix) + } + } + return nil +} + +func getHostFromAddr(addr net.Addr) (string, error) { + host, _, err := net.SplitHostPort(addr.String()) + return host, err } // GetSharedConn returns the shared udp conn diff --git a/iface/freebsd/errors.go b/iface/freebsd/errors.go new file mode 100644 index 000000000..e2c6a2aa9 --- /dev/null +++ b/iface/freebsd/errors.go @@ -0,0 +1,8 @@ +package freebsd + +import "errors" + +var ( + ErrDoesNotExist = errors.New("does not exist") + ErrNameDoesNotMatch = errors.New("name does not match") +) diff --git a/iface/freebsd/iface.go b/iface/freebsd/iface.go new file mode 100644 index 000000000..d32fa6436 --- /dev/null +++ b/iface/freebsd/iface.go @@ -0,0 +1,108 @@ +package freebsd + +import ( + "bufio" + "fmt" + "strconv" + "strings" +) + +type iface struct { + Name string + MTU int + Group string + IPAddrs []string +} + +func parseError(output []byte) error { + // TODO: implement without allocations + lines := string(output) + + if strings.Contains(lines, "does not exist") { + return ErrDoesNotExist + } + + return nil +} + +func parseIfconfigOutput(output []byte) (*iface, error) { + // TODO: implement without allocations + lines := string(output) + + scanner := bufio.NewScanner(strings.NewReader(lines)) + + var name, mtu, group string + var ips []string + + for scanner.Scan() { + line := scanner.Text() + + // If line contains ": flags", it's a line with interface information + if strings.Contains(line, ": flags") { + parts := strings.Fields(line) + if len(parts) < 4 { + return nil, fmt.Errorf("failed to parse line: %s", line) + } + name = strings.TrimSuffix(parts[0], ":") + if strings.Contains(line, "mtu") { + mtuIndex := 0 + for i, part := range parts { + if part == "mtu" { + mtuIndex = i + break + } + } + mtu = parts[mtuIndex+1] + } + } + + // If line contains "groups:", it's a line with interface group + if strings.Contains(line, "groups:") { + parts := strings.Fields(line) + if len(parts) < 2 { + return nil, fmt.Errorf("failed to parse line: %s", line) + } + group = parts[1] + } + + // If line contains "inet ", it's a line with IP address + if strings.Contains(line, "inet ") { + parts := strings.Fields(line) + if len(parts) < 2 { + return nil, fmt.Errorf("failed to parse line: %s", line) + } + ips = append(ips, parts[1]) + } + } + + if name == "" { + return nil, fmt.Errorf("interface name not found in ifconfig output") + } + + mtuInt, err := strconv.Atoi(mtu) + if err != nil { + return nil, fmt.Errorf("failed to parse MTU: %w", err) + } + + return &iface{ + Name: name, + MTU: mtuInt, + Group: group, + IPAddrs: ips, + }, nil +} + +func parseIFName(output []byte) (string, error) { + // TODO: implement without allocations + lines := strings.Split(string(output), "\n") + if len(lines) == 0 || lines[0] == "" { + return "", fmt.Errorf("no output returned") + } + + fields := strings.Fields(lines[0]) + if len(fields) > 1 { + return "", fmt.Errorf("invalid output") + } + + return fields[0], nil +} diff --git a/iface/freebsd/iface_internal_test.go b/iface/freebsd/iface_internal_test.go new file mode 100644 index 000000000..f933ae634 --- /dev/null +++ b/iface/freebsd/iface_internal_test.go @@ -0,0 +1,76 @@ +package freebsd + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseIfconfigOutput(t *testing.T) { + testOutput := `wg1: flags=8080 metric 0 mtu 1420 + options=80000 + groups: wg + nd6 options=109` + + expected := &iface{ + Name: "wg1", + MTU: 1420, + Group: "wg", + } + + result, err := parseIfconfigOutput(([]byte)(testOutput)) + if err != nil { + t.Errorf("Error parsing ifconfig output: %v", err) + return + } + + assert.Equal(t, expected.Name, result.Name, "Name should match") + assert.Equal(t, expected.MTU, result.MTU, "MTU should match") + assert.Equal(t, expected.Group, result.Group, "Group should match") +} + +func TestParseIFName(t *testing.T) { + tests := []struct { + name string + output string + expected string + expectedErr error + }{ + { + name: "ValidOutput", + output: "eth0\n", + expected: "eth0", + }, + { + name: "ValidOutputOneLine", + output: "eth0", + expected: "eth0", + }, + { + name: "EmptyOutput", + output: "", + expectedErr: fmt.Errorf("no output returned"), + }, + { + name: "InvalidOutput", + output: "This is an invalid output\n", + expectedErr: fmt.Errorf("invalid output"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result, err := parseIFName(([]byte)(test.output)) + + assert.Equal(t, test.expected, result, "Interface names should match") + + if test.expectedErr != nil { + assert.NotNil(t, err, "Error should not be nil") + assert.EqualError(t, err, test.expectedErr.Error(), "Error messages should match") + } else { + assert.Nil(t, err, "Error should be nil") + } + }) + } +} diff --git a/iface/freebsd/link.go b/iface/freebsd/link.go new file mode 100644 index 000000000..b7924f04b --- /dev/null +++ b/iface/freebsd/link.go @@ -0,0 +1,239 @@ +package freebsd + +import ( + "bytes" + "errors" + "fmt" + "os/exec" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const wgIFGroup = "wg" + +// Link represents a network interface. +type Link struct { + name string +} + +func NewLink(name string) *Link { + return &Link{ + name: name, + } +} + +// LinkByName retrieves a network interface by its name. +func LinkByName(name string) (*Link, error) { + out, err := exec.Command("ifconfig", name).CombinedOutput() + if err != nil { + if pErr := parseError(out); pErr != nil { + return nil, pErr + } + + log.Debugf("ifconfig out: %s", out) + + return nil, fmt.Errorf("command run: %w", err) + } + + i, err := parseIfconfigOutput(out) + if err != nil { + return nil, fmt.Errorf("parse ifconfig output: %w", err) + } + + if i.Name != name { + return nil, ErrNameDoesNotMatch + } + + return &Link{name: i.Name}, nil +} + +// Recreate - create new interface, remove current before create if it exists +func (l *Link) Recreate() error { + ok, err := l.isExist() + if err != nil { + return fmt.Errorf("is exist: %w", err) + } + + if ok { + if err := l.del(l.name); err != nil { + return fmt.Errorf("del: %w", err) + } + } + + return l.Add() +} + +// Add creates a new network interface. +func (l *Link) Add() error { + parsedName, err := l.create(wgIFGroup) + if err != nil { + return fmt.Errorf("create link: %w", err) + } + + if parsedName == l.name { + return nil + } + + parsedName, err = l.rename(parsedName, l.name) + if err != nil { + errDel := l.del(parsedName) + if errDel != nil { + return fmt.Errorf("del on rename link: %w: %w", err, errDel) + } + + return fmt.Errorf("rename link: %w", err) + } + + return nil +} + +// Del removes an existing network interface. +func (l *Link) Del() error { + return l.del(l.name) +} + +// SetMTU sets the MTU of the network interface. +func (l *Link) SetMTU(mtu int) error { + return l.setMTU(mtu) +} + +// AssignAddr assigns an IP address and netmask to the network interface. +func (l *Link) AssignAddr(ip, netmask string) error { + return l.setAddr(ip, netmask) +} + +func (l *Link) Up() error { + return l.up(l.name) +} + +func (l *Link) Down() error { + return l.down(l.name) +} + +func (l *Link) isExist() (bool, error) { + _, err := LinkByName(l.name) + if errors.Is(err, ErrDoesNotExist) { + return false, nil + } + + if err != nil { + return false, fmt.Errorf("link by name: %w", err) + } + + return true, nil +} + +func (l *Link) create(groupName string) (string, error) { + cmd := exec.Command("ifconfig", groupName, "create") + + output, err := cmd.CombinedOutput() + if err != nil { + log.Debugf("ifconfig out: %s", output) + + return "", fmt.Errorf("create %s interface: %w", groupName, err) + } + + interfaceName, err := parseIFName(output) + if err != nil { + return "", fmt.Errorf("parse interface name: %w", err) + } + + return interfaceName, nil +} + +func (l *Link) rename(oldName, newName string) (string, error) { + cmd := exec.Command("ifconfig", oldName, "name", newName) + + output, err := cmd.CombinedOutput() + if err != nil { + log.Debugf("ifconfig out: %s", output) + + return "", fmt.Errorf("change name %q -> %q: %w", oldName, newName, err) + } + + interfaceName, err := parseIFName(output) + if err != nil { + return "", fmt.Errorf("parse new name: %w", err) + } + + return interfaceName, nil +} + +func (l *Link) del(name string) error { + var stderr bytes.Buffer + + cmd := exec.Command("ifconfig", name, "destroy") + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + log.Debugf("ifconfig out: %s", stderr.String()) + + return fmt.Errorf("destroy %s interface: %w", name, err) + } + + return nil +} + +func (l *Link) setMTU(mtu int) error { + var stderr bytes.Buffer + + cmd := exec.Command("ifconfig", l.name, "mtu", strconv.Itoa(mtu)) + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + log.Debugf("ifconfig out: %s", stderr.String()) + + return fmt.Errorf("set interface mtu: %w", err) + } + + return nil +} + +func (l *Link) setAddr(ip, netmask string) error { + var stderr bytes.Buffer + + cmd := exec.Command("ifconfig", l.name, "inet", ip, "netmask", netmask) + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + log.Debugf("ifconfig out: %s", stderr.String()) + + return fmt.Errorf("set interface addr: %w", err) + } + + return nil +} + +func (l *Link) up(name string) error { + var stderr bytes.Buffer + + cmd := exec.Command("ifconfig", name, "up") + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + log.Debugf("ifconfig out: %s", stderr.String()) + + return fmt.Errorf("up %s interface: %w", name, err) + } + + return nil +} + +func (l *Link) down(name string) error { + var stderr bytes.Buffer + + cmd := exec.Command("ifconfig", name, "down") + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + log.Debugf("ifconfig out: %s", stderr.String()) + + return fmt.Errorf("down %s interface: %w", name, err) + } + + return nil +} diff --git a/iface/iface.go b/iface/iface.go index 3ae40ad4c..928077a3d 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -48,6 +48,19 @@ func (w *WGIface) Address() WGAddress { return w.tun.WgAddress() } +// ToInterface returns the net.Interface for the Wireguard interface +func (r *WGIface) ToInterface() *net.Interface { + name := r.tun.DeviceName() + intf, err := net.InterfaceByName(name) + if err != nil { + log.Warnf("Failed to get interface by name %s: %v", name, err) + intf = &net.Interface{ + Name: name, + } + } + return intf +} + // Up configures a Wireguard interface // The interface must exist before calling this method (e.g. call interface.Create() before) func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { @@ -94,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { w.mu.Lock() defer w.mu.Unlock() - log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) + log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) return w.configurer.addAllowedIP(peerKey, allowedIP) } @@ -103,7 +116,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { w.mu.Lock() defer w.mu.Unlock() - log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) + log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) return w.configurer.removeAllowedIP(peerKey, allowedIP) } diff --git a/iface/iface_android.go b/iface/iface_android.go index d1876e495..99f6885a5 100644 --- a/iface/iface_android.go +++ b/iface/iface_android.go @@ -4,17 +4,19 @@ import ( "fmt" "github.com/pion/transport/v3" + + "github.com/netbirdio/netbird/iface/bind" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter), + tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_create.go b/iface/iface_create.go index 86c3f320f..cfc555f2e 100644 --- a/iface/iface_create.go +++ b/iface/iface_create.go @@ -1,5 +1,4 @@ //go:build !android -// +build !android package iface diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go index 4d62c6af6..15e4a7817 100644 --- a/iface/iface_darwin.go +++ b/iface/iface_darwin.go @@ -1,5 +1,4 @@ //go:build !ios -// +build !ios package iface @@ -8,11 +7,12 @@ import ( "github.com/pion/transport/v3" + "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err @@ -23,11 +23,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) + wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } diff --git a/iface/iface_ios.go b/iface/iface_ios.go index b22e1a6a4..6babe5964 100644 --- a/iface/iface_ios.go +++ b/iface/iface_ios.go @@ -1,5 +1,4 @@ //go:build ios -// +build ios package iface @@ -7,16 +6,18 @@ import ( "fmt" "github.com/pion/transport/v3" + + "github.com/netbirdio/netbird/iface/bind" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd), + tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_test.go b/iface/iface_test.go index f227eaf83..43c44b770 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -41,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func Test_CreateInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -149,7 +149,7 @@ func Test_Close(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -182,7 +182,7 @@ func Test_ConfigureInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -230,7 +230,7 @@ func Test_UpdatePeer(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -291,7 +291,7 @@ func Test_RemovePeer(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) + iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -345,7 +345,7 @@ func Test_ConnectPeers(t *testing.T) { t.Fatal(err) } - iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil) + iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } @@ -368,7 +368,7 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil) + iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) } diff --git a/iface/iface_linux.go b/iface/iface_unix.go similarity index 76% rename from iface/iface_linux.go rename to iface/iface_unix.go index 62ae0f0de..9608df1ad 100644 --- a/iface/iface_linux.go +++ b/iface/iface_unix.go @@ -1,18 +1,19 @@ -//go:build !android -// +build !android +//go:build (linux && !android) || freebsd package iface import ( "fmt" + "runtime" "github.com/pion/transport/v3" + "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err @@ -22,7 +23,7 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, // move the kernel/usp/netstack preference evaluation to upper layer if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) + wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) wgIFace.userspaceBind = true return wgIFace, nil } @@ -36,12 +37,12 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, if !tunModuleIsLoaded() { return nil, fmt.Errorf("couldn't check or load tun module") } - wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) wgIFace.userspaceBind = true return wgIFace, nil } // CreateOnAndroid this function make sense on mobile only func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("this function has not implemented on this platform") + return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS) } diff --git a/iface/iface_windows.go b/iface/iface_windows.go index d3a16a52f..c5edd27a9 100644 --- a/iface/iface_windows.go +++ b/iface/iface_windows.go @@ -5,11 +5,12 @@ import ( "github.com/pion/transport/v3" + "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { wgAddress, err := parseWGAddress(address) if err != nil { return nil, err @@ -20,11 +21,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) + wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } diff --git a/iface/module.go b/iface/module.go index 7f337201d..ca70cf3c7 100644 --- a/iface/module.go +++ b/iface/module.go @@ -1,5 +1,4 @@ -//go:build !linux || android -// +build !linux android +//go:build (!linux && !freebsd) || android package iface diff --git a/iface/module_freebsd.go b/iface/module_freebsd.go new file mode 100644 index 000000000..00ad882c2 --- /dev/null +++ b/iface/module_freebsd.go @@ -0,0 +1,18 @@ +package iface + +// WireGuardModuleIsLoaded check if kernel support wireguard +func WireGuardModuleIsLoaded() bool { + // Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd) + // we are currently do not use it, since it is required to add wireguard kernel support to + // - https://github.com/netbirdio/netbird/tree/main/sharedsock + // - https://github.com/mdlayher/socket + // TODO: implement kernel space + return false +} + +// tunModuleIsLoaded check if tun module exist, if is not attempt to load it +func tunModuleIsLoaded() bool { + // Assume tun supported by freebsd kernel by default + // TODO: implement check for module loaded in kernel or build-it + return true +} diff --git a/iface/name.go b/iface/name.go index 05d0299d3..706cb65ad 100644 --- a/iface/name.go +++ b/iface/name.go @@ -1,5 +1,4 @@ -//go:build linux || windows -// +build linux windows +//go:build linux || windows || freebsd package iface diff --git a/iface/name_darwin.go b/iface/name_darwin.go index c80f790f5..a4016ce15 100644 --- a/iface/name_darwin.go +++ b/iface/name_darwin.go @@ -1,5 +1,4 @@ //go:build darwin -// +build darwin package iface diff --git a/iface/tun_android.go b/iface/tun_android.go index 834b2cb42..dc6abea36 100644 --- a/iface/tun_android.go +++ b/iface/tun_android.go @@ -31,13 +31,13 @@ type wgTunDevice struct { configurer wgConfigurer } -func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter) wgTunDevice { +func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice { return wgTunDevice{ address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), tunAdapter: tunAdapter, } } diff --git a/iface/tun_darwin.go b/iface/tun_darwin.go index 8dc10bd0e..7d684f52e 100644 --- a/iface/tun_darwin.go +++ b/iface/tun_darwin.go @@ -27,14 +27,14 @@ type tunDevice struct { configurer wgConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { +func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { return &tunDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), } } diff --git a/iface/tun_ios.go b/iface/tun_ios.go index ea980818d..83e26e08d 100644 --- a/iface/tun_ios.go +++ b/iface/tun_ios.go @@ -29,13 +29,13 @@ type tunDevice struct { configurer wgConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int) *tunDevice { +func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice { return &tunDevice{ name: name, address: address, port: port, key: key, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), tunFd: tunFd, } } diff --git a/iface/tun_kernel_linux.go b/iface/tun_kernel_unix.go similarity index 63% rename from iface/tun_kernel_linux.go rename to iface/tun_kernel_unix.go index 12adcdf73..019dd786b 100644 --- a/iface/tun_kernel_linux.go +++ b/iface/tun_kernel_unix.go @@ -1,4 +1,4 @@ -//go:build linux && !android +//go:build (linux && !android) || freebsd package iface @@ -6,11 +6,9 @@ import ( "context" "fmt" "net" - "os" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" - "github.com/vishvananda/netlink" "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/sharedsock" @@ -29,9 +27,13 @@ type tunKernelDevice struct { link *wgLink udpMuxConn net.PacketConn udpMux *bind.UniversalUDPMuxDefault + + filterFn bind.FilterFn } func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { + checkUser() + ctx, cancel := context.WithCancel(context.Background()) return &tunKernelDevice{ ctx: ctx, @@ -48,53 +50,29 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in func (t *tunKernelDevice) Create() (wgConfigurer, error) { link := newWGLink(t.name) - // check if interface exists - l, err := netlink.LinkByName(t.name) - if err != nil { - switch err.(type) { - case netlink.LinkNotFoundError: - break - default: - return nil, err - } - } - - // remove if interface exists - if l != nil { - err = netlink.LinkDel(link) - if err != nil { - return nil, err - } - } - - log.Debugf("adding device: %s", t.name) - err = netlink.LinkAdd(link) - if os.IsExist(err) { - log.Infof("interface %s already exists. Will reuse.", t.name) - } else if err != nil { - return nil, err + if err := link.recreate(); err != nil { + return nil, fmt.Errorf("recreate: %w", err) } t.link = link - err = t.assignAddr() - if err != nil { - return nil, err + if err := t.assignAddr(); err != nil { + return nil, fmt.Errorf("assign addr: %w", err) } - // todo do a discovery + // TODO: do a MTU discovery log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name) - err = netlink.LinkSetMTU(link, t.mtu) - if err != nil { - log.Errorf("error setting MTU on interface: %s", t.name) - return nil, err + + if err := link.setMTU(t.mtu); err != nil { + return nil, fmt.Errorf("set mtu: %w", err) } configurer := newWGConfigurer(t.name) - err = configurer.configureInterface(t.key, t.wgPort) - if err != nil { + + if err := configurer.configureInterface(t.key, t.wgPort); err != nil { return nil, err } + return configurer, nil } @@ -108,9 +86,10 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { } log.Debugf("bringing up interface: %s", t.name) - err := netlink.LinkSetUp(t.link) - if err != nil { + + if err := t.link.up(); err != nil { log.Errorf("error bringing up interface: %s", t.name) + return nil, err } @@ -119,8 +98,9 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return nil, err } bindParams := bind.UniversalUDPMuxParams{ - UDPConn: rawSock, - Net: t.transportNet, + UDPConn: rawSock, + Net: t.transportNet, + FilterFn: t.filterFn, } mux := bind.NewUniversalUDPMuxDefault(bindParams) go mux.ReadFromConn(t.ctx) @@ -178,32 +158,5 @@ func (t *tunKernelDevice) Wrapper() *DeviceWrapper { // assignAddr Adds IP address to the tunnel interface func (t *tunKernelDevice) assignAddr() error { - link := newWGLink(t.name) - - //delete existing addresses - list, err := netlink.AddrList(link, 0) - if err != nil { - return err - } - if len(list) > 0 { - for _, a := range list { - addr := a - err = netlink.AddrDel(link, &addr) - if err != nil { - return err - } - } - } - - log.Debugf("adding address %s to interface: %s", t.address.String(), t.name) - addr, _ := netlink.ParseAddr(t.address.String()) - err = netlink.AddrAdd(link, addr) - if os.IsExist(err) { - log.Infof("interface %s already has the address: %s", t.name, t.address.String()) - } else if err != nil { - return err - } - // On linux, the link must be brought up - err = netlink.LinkSetUp(link) - return err + return t.link.assignAddr(t.address) } diff --git a/iface/tun_link_freebsd.go b/iface/tun_link_freebsd.go new file mode 100644 index 000000000..be7921fdb --- /dev/null +++ b/iface/tun_link_freebsd.go @@ -0,0 +1,80 @@ +package iface + +import ( + "fmt" + + "github.com/netbirdio/netbird/iface/freebsd" + log "github.com/sirupsen/logrus" +) + +type wgLink struct { + name string + link *freebsd.Link +} + +func newWGLink(name string) *wgLink { + link := freebsd.NewLink(name) + + return &wgLink{ + name: name, + link: link, + } +} + +// Type returns the interface type +func (l *wgLink) Type() string { + return "wireguard" +} + +// Close deletes the link interface +func (l *wgLink) Close() error { + return l.link.Del() +} + +func (l *wgLink) recreate() error { + if err := l.link.Recreate(); err != nil { + return fmt.Errorf("recreate: %w", err) + } + + return nil +} + +func (l *wgLink) setMTU(mtu int) error { + if err := l.link.SetMTU(mtu); err != nil { + return fmt.Errorf("set mtu: %w", err) + } + + return nil +} + +func (l *wgLink) up() error { + if err := l.link.Up(); err != nil { + return fmt.Errorf("up: %w", err) + } + + return nil +} + +func (l *wgLink) assignAddr(address WGAddress) error { + link, err := freebsd.LinkByName(l.name) + if err != nil { + return fmt.Errorf("link by name: %w", err) + } + + ip := address.IP.String() + mask := "0x" + address.Network.Mask.String() + + log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) + + err = link.AssignAddr(ip, mask) + if err != nil { + return fmt.Errorf("assign addr: %w", err) + } + + err = link.Up() + if err != nil { + return fmt.Errorf("up: %w", err) + } + + return nil +} diff --git a/iface/tun_link_linux.go b/iface/tun_link_linux.go index ab28b7e38..3ce644e84 100644 --- a/iface/tun_link_linux.go +++ b/iface/tun_link_linux.go @@ -2,7 +2,13 @@ package iface -import "github.com/vishvananda/netlink" +import ( + "fmt" + "os" + + log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" +) type wgLink struct { attrs *netlink.LinkAttrs @@ -31,3 +37,97 @@ func (l *wgLink) Type() string { func (l *wgLink) Close() error { return netlink.LinkDel(l) } + +func (l *wgLink) recreate() error { + name := l.attrs.Name + + // check if interface exists + link, err := netlink.LinkByName(name) + if err != nil { + switch err.(type) { + case netlink.LinkNotFoundError: + break + default: + return fmt.Errorf("link by name: %w", err) + } + } + + // remove if interface exists + if link != nil { + err = netlink.LinkDel(l) + if err != nil { + return err + } + } + + log.Debugf("adding device: %s", name) + err = netlink.LinkAdd(l) + if os.IsExist(err) { + log.Infof("interface %s already exists. Will reuse.", name) + } else if err != nil { + return fmt.Errorf("link add: %w", err) + } + + return nil +} + +func (l *wgLink) setMTU(mtu int) error { + if err := netlink.LinkSetMTU(l, mtu); err != nil { + log.Errorf("error setting MTU on interface: %s", l.attrs.Name) + + return fmt.Errorf("link set mtu: %w", err) + } + + return nil +} + +func (l *wgLink) up() error { + if err := netlink.LinkSetUp(l); err != nil { + log.Errorf("error bringing up interface: %s", l.attrs.Name) + return fmt.Errorf("link setup: %w", err) + } + + return nil +} + +func (l *wgLink) assignAddr(address WGAddress) error { + //delete existing addresses + list, err := netlink.AddrList(l, 0) + if err != nil { + return fmt.Errorf("list addr: %w", err) + } + + if len(list) > 0 { + for _, a := range list { + addr := a + err = netlink.AddrDel(l, &addr) + if err != nil { + return fmt.Errorf("del addr: %w", err) + } + } + } + + name := l.attrs.Name + addrStr := address.String() + + log.Debugf("adding address %s to interface: %s", addrStr, name) + + addr, err := netlink.ParseAddr(addrStr) + if err != nil { + return fmt.Errorf("parse addr: %w", err) + } + + err = netlink.AddrAdd(l, addr) + if os.IsExist(err) { + log.Infof("interface %s already has the address: %s", name, addrStr) + } else if err != nil { + return fmt.Errorf("add addr: %w", err) + } + + // On linux, the link must be brought up + if err := netlink.LinkSetUp(l); err != nil { + return fmt.Errorf("link setup: %w", err) + } + + return nil +} diff --git a/iface/tun_netstack.go b/iface/tun_netstack.go index e1d01ecc9..beb3acc3f 100644 --- a/iface/tun_netstack.go +++ b/iface/tun_netstack.go @@ -30,7 +30,7 @@ type tunNetstackDevice struct { configurer wgConfigurer } -func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice { +func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice { return &tunNetstackDevice{ name: name, address: address, @@ -38,7 +38,7 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string key: key, mtu: mtu, listenAddress: listenAddress, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), } } diff --git a/iface/tun_usp_linux.go b/iface/tun_usp_unix.go similarity index 75% rename from iface/tun_usp_linux.go rename to iface/tun_usp_unix.go index 9f0210228..b18794b25 100644 --- a/iface/tun_usp_linux.go +++ b/iface/tun_usp_unix.go @@ -1,14 +1,14 @@ -//go:build linux && !android +//go:build (linux && !android) || freebsd package iface import ( "fmt" "os" + "runtime" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" - "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -29,15 +29,18 @@ type tunUSPDevice struct { configurer wgConfigurer } -func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { +func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { log.Infof("using userspace bind mode") + + checkUser() + return &tunUSPDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), } } @@ -129,30 +132,14 @@ func (t *tunUSPDevice) Wrapper() *DeviceWrapper { func (t *tunUSPDevice) assignAddr() error { link := newWGLink(t.name) - //delete existing addresses - list, err := netlink.AddrList(link, 0) - if err != nil { - return err - } - if len(list) > 0 { - for _, a := range list { - addr := a - err = netlink.AddrDel(link, &addr) - if err != nil { - return err - } + return link.assignAddr(t.address) +} + +func checkUser() { + if runtime.GOOS == "freebsd" { + euid := os.Geteuid() + if euid != 0 { + log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig") } } - - log.Debugf("adding address %s to interface: %s", t.address.String(), t.name) - addr, _ := netlink.ParseAddr(t.address.String()) - err = netlink.AddrAdd(link, addr) - if os.IsExist(err) { - log.Infof("interface %s already has the address: %s", t.name, t.address.String()) - } else if err != nil { - return err - } - // On linux, the link must be brought up - err = netlink.LinkSetUp(link) - return err } diff --git a/iface/tun_windows.go b/iface/tun_windows.go index 900e62fc3..5c77f1d16 100644 --- a/iface/tun_windows.go +++ b/iface/tun_windows.go @@ -29,14 +29,14 @@ type tunDevice struct { configurer wgConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { +func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { return &tunDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet), + iceBind: bind.NewICEBind(transportNet, filterFn), } } diff --git a/iface/wg_configurer.go b/iface/wg_configurer.go index 91c57eb9c..dd38ba075 100644 --- a/iface/wg_configurer.go +++ b/iface/wg_configurer.go @@ -1,12 +1,15 @@ package iface import ( + "errors" "net" "time" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +var ErrPeerNotFound = errors.New("peer not found") + type wgConfigurer interface { configureInterface(privateKey string, port int) error updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel_unix.go similarity index 90% rename from iface/wg_configurer_kernel.go rename to iface/wg_configurer_kernel_unix.go index 67bfb716d..48ea70b7b 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel_unix.go @@ -1,4 +1,4 @@ -//go:build linux && !android +//go:build (linux && !android) || freebsd package iface @@ -125,17 +125,17 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { - return err + return fmt.Errorf("parse allowed IP: %w", err) } peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { - return err + return fmt.Errorf("parse peer key: %w", err) } existingPeer, err := c.getPeer(c.deviceName, peerKey) if err != nil { - return err + return fmt.Errorf("get peer: %w", err) } newAllowedIPs := existingPeer.AllowedIPs @@ -159,7 +159,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e } err = c.configure(config) if err != nil { - return fmt.Errorf(`received error "%w" while removing allowed IP from peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP) + return fmt.Errorf("remove allowed IP %s on interface %s: %w", allowedIP, c.deviceName, err) } return nil } @@ -167,25 +167,25 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { wg, err := wgctrl.New() if err != nil { - return wgtypes.Peer{}, err + return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err) } defer func() { err = wg.Close() if err != nil { - log.Errorf("got error while closing wgctl: %v", err) + log.Errorf("Got error while closing wgctl: %v", err) } }() wgDevice, err := wg.Device(ifaceName) if err != nil { - return wgtypes.Peer{}, err + return wgtypes.Peer{}, fmt.Errorf("get device %s: %w", ifaceName, err) } for _, peer := range wgDevice.Peers { if peer.PublicKey.String() == peerPubKey { return peer, nil } } - return wgtypes.Peer{}, fmt.Errorf("peer not found") + return wgtypes.Peer{}, ErrPeerNotFound } func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { @@ -200,7 +200,6 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { if err != nil { return err } - log.Tracef("got Wireguard device %s", c.deviceName) return wg.ConfigureDevice(c.deviceName, config) } diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 0c1b6e85a..04a29a60b 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -17,6 +17,8 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) +var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") + type wgUSPConfigurer struct { device *device.Device deviceName string @@ -173,7 +175,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { } if !removedAllowedIP { - return fmt.Errorf("allowedIP not found") + return ErrAllowedIPNotFound } config := wgtypes.Config{ Peers: []wgtypes.PeerConfig{peer}, @@ -301,7 +303,7 @@ func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (m } } if !foundPeer { - return nil, fmt.Errorf("peer not found: %s", peerKey) + return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey) } return configFound, nil diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 747eebd53..6b6831493 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -28,7 +28,11 @@ services: - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL volumes: - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/ - + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" # Signal signal: image: netbirdio/signal:$NETBIRD_SIGNAL_TAG @@ -40,6 +44,11 @@ services: # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" # Management management: @@ -63,12 +72,16 @@ services: "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" ] - + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" # Coturn coturn: image: coturn/coturn:$COTURN_TAG restart: unless-stopped - domainname: $TURN_DOMAIN + #domainname: $TURN_DOMAIN # only needed when TLS is enabled volumes: - ./turnserver.conf:/etc/turnserver.conf:ro # - ./privkey.pem:/etc/coturn/private/privkey.pem:ro @@ -76,7 +89,11 @@ services: network_mode: host command: - -c /etc/turnserver.conf - + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" volumes: $MGMT_VOLUMENAME: $SIGNAL_VOLUMENAME: diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 29f0e4606..5c33e2db6 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -50,7 +50,7 @@ check_jq() { wait_crdb() { set +e while true; do - if $DOCKER_COMPOSE_COMMAND exec -T crdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then + if $DOCKER_COMPOSE_COMMAND exec -T zdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then break fi echo -n " ." @@ -61,14 +61,16 @@ wait_crdb() { } init_crdb() { - echo -e "\nInitializing Zitadel's CockroachDB\n\n" - $DOCKER_COMPOSE_COMMAND up -d crdb - echo "" - # shellcheck disable=SC2028 - echo -n "Waiting cockroachDB to become ready " - wait_crdb - $DOCKER_COMPOSE_COMMAND exec -T crdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/" - handle_request_command_status $? "init_crdb failed" "" + if [[ $ZITADEL_DATABASE == "cockroach" ]]; then + echo -e "\nInitializing Zitadel's CockroachDB\n\n" + $DOCKER_COMPOSE_COMMAND up -d zdb + echo "" + # shellcheck disable=SC2028 + echo -n "Waiting CockroachDB to become ready" + wait_crdb + $DOCKER_COMPOSE_COMMAND exec -T zdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/" + handle_request_command_status $? "init_crdb failed" "" + fi } get_main_ip_address() { @@ -156,7 +158,7 @@ create_new_application() { "'"$BASE_REDIRECT_URL2"'" ], "postLogoutRedirectUris": [ - "'"$LOGOUT_URL"'" + "'"$LOGOUT_URL"'" ], "RESPONSETypes": [ "OIDC_RESPONSE_TYPE_CODE" @@ -461,6 +463,20 @@ initEnvironment() { exit 1 fi + if [[ $ZITADEL_DATABASE == "cockroach" ]]; then + echo "Use CockroachDB as Zitadel database." + ZDB=$(renderDockerComposeCockroachDB) + ZITADEL_DB_ENV=$(renderZitadelCockroachDBEnv) + else + echo "Use Postgres as default Zitadel database." + echo "For using CockroachDB please the environment variable 'export ZITADEL_DATABASE=cockroach'." + POSTGRES_ROOT_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@" + POSTGRES_ZITADEL_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@" + ZDB=$(renderDockerComposePostgres) + ZITADEL_DB_ENV=$(renderZitadelPostgresEnv) + renderPostgresEnv > zdb.env + fi + echo Rendering initial files... renderDockerCompose > docker-compose.yml renderCaddyfile > Caddyfile @@ -474,7 +490,7 @@ initEnvironment() { init_crdb - echo -e "\nStarting Zidatel IDP for user management\n\n" + echo -e "\nStarting Zitadel IDP for user management\n\n" $DOCKER_COMPOSE_COMMAND up -d caddy zitadel init_zitadel @@ -634,15 +650,15 @@ renderManagementJson() { "ExtraConfig": { "ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1" } - }, - "DeviceAuthorizationFlow": { - "Provider": "hosted", - "ProviderConfig": { - "Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI", - "ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI", - "Scope": "openid" - } - }, + }, + "DeviceAuthorizationFlow": { + "Provider": "hosted", + "ProviderConfig": { + "Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI", + "ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI", + "Scope": "openid" + } + }, "PKCEAuthorizationFlow": { "ProviderConfig": { "Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI", @@ -679,16 +695,6 @@ renderZitadelEnv() { cat < 0 && trustedProxiesCount > 0 { - log.Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " + + log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " + "This is not recommended way to extract X-Forwarded-For. Consider using one of these options.") } realipOpts := []realip.Option{ @@ -206,8 +218,8 @@ var ( gRPCOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), - grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...)), - grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...)), + grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor), + grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor), } var certManager *autocert.Manager @@ -224,7 +236,7 @@ var ( } else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" { tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey) if err != nil { - log.Errorf("cannot load TLS credentials: %v", err) + log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err) return err } transportCredentials := credentials.NewTLS(tlsConfig) @@ -233,6 +245,7 @@ var ( } jwtValidator, err := jwtclaims.NewJWTValidator( + ctx, config.HttpConfig.AuthIssuer, config.GetAuthAudiences(), config.HttpConfig.AuthKeysLocation, @@ -249,26 +262,24 @@ var ( KeysLocation: config.HttpConfig.AuthKeysLocation, } - ctx, cancel := context.WithCancel(cmd.Context()) - defer cancel() httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } ephemeralManager := server.NewEphemeralManager(store, accountManager) - ephemeralManager.LoadInitialPeers() + ephemeralManager.LoadInitialPeers(ctx) gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnRelayTokenManager, appMetrics, ephemeralManager) + srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnRelayTokenManager, appMetrics, ephemeralManager) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) - installationID, err := getInstallationID(store) + installationID, err := getInstallationID(ctx, store) if err != nil { - log.Errorf("cannot load TLS credentials: %v", err) + log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err) return err } @@ -278,18 +289,18 @@ var ( idpManager = config.IdpManagerConfig.ManagerType } metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager) - go metricsWorker.Run() + go metricsWorker.Run(ctx) } var compatListener net.Listener if mgmtPort != ManagementLegacyPort { // The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it // are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073. - compatListener, err = serveGRPC(gRPCAPIHandler, ManagementLegacyPort) + compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort) if err != nil { return err } - log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) + log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler) @@ -306,8 +317,8 @@ var ( if err != nil { return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err) } - log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String()) - serveHTTP(cml, certManager.HTTPHandler(nil)) + log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String()) + serveHTTP(ctx, cml, certManager.HTTPHandler(nil)) } } else if tlsConfig != nil { listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig) @@ -321,14 +332,14 @@ var ( } } - log.Infof("management server version %s", version.NetbirdVersion()) - log.Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String()) - serveGRPCWithHTTP(listener, rootHandler, tlsEnabled) + log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion()) + log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String()) + serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled) SetupCloseHandler() <-stopCh - integratedPeerValidator.Stop() + integratedPeerValidator.Stop(ctx) if geo != nil { _ = geo.Stop() } @@ -339,39 +350,68 @@ var ( _ = certManager.Listener().Close() } gRPCAPIHandler.Stop() - _ = store.Close() - _ = eventStore.Close() - log.Infof("stopped Management Service") + _ = store.Close(ctx) + _ = eventStore.Close(ctx) + log.WithContext(ctx).Infof("stopped Management Service") return nil }, } ) -func notifyStop(msg string) { +func unaryInterceptor( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, +) (interface{}, error) { + reqID := uuid.New().String() + //nolint + ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.GRPCSource) + //nolint + ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID) + return handler(ctx, req) +} + +func streamInterceptor( + srv interface{}, + ss grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, +) error { + reqID := uuid.New().String() + wrapped := grpcMiddleware.WrapServerStream(ss) + //nolint + ctx := context.WithValue(ss.Context(), formatter.ExecutionContextKey, formatter.GRPCSource) + //nolint + wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID) + return handler(srv, wrapped) +} + +func notifyStop(ctx context.Context, msg string) { select { case stopCh <- 1: - log.Error(msg) + log.WithContext(ctx).Error(msg) default: // stop has been already called, nothing to report } } -func getInstallationID(store server.Store) (string, error) { +func getInstallationID(ctx context.Context, store server.Store) (string, error) { installationID := store.GetInstallationID() if installationID != "" { return installationID, nil } installationID = strings.ToUpper(uuid.New().String()) - err := store.SaveInstallationID(installationID) + err := store.SaveInstallationID(ctx, installationID) if err != nil { return "", err } return installationID, nil } -func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) { +func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { return nil, err @@ -379,22 +419,22 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) { go func() { err := grpcServer.Serve(listener) if err != nil { - notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err)) + notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err)) } }() return listener, nil } -func serveHTTP(httpListener net.Listener, handler http.Handler) { +func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) { go func() { err := http.Serve(httpListener, handler) if err != nil { - notifyStop(fmt.Sprintf("failed running HTTP server: %v", err)) + notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err)) } }() } -func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled bool) { +func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) { go func() { var err error if tlsEnabled { @@ -411,7 +451,7 @@ func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled b if err != nil { select { case stopCh <- 1: - log.Errorf("failed to serve HTTP and gRPC server: %v", err) + log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err) default: // stop has been already called, nothing to report } @@ -431,7 +471,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle }) } -func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { +func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { loadedConfig := &server.Config{} _, err := util.ReadJson(mgmtConfigPath, loadedConfig) if err != nil { @@ -452,26 +492,26 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint if oidcEndpoint != "" { // if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically - log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint) - oidcConfig, err := fetchOIDCConfig(oidcEndpoint) + log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint) + oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint) if err != nil { return nil, err } - log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint) + log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint) - log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s", oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer) loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer - log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s", oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation) loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) { - log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint) loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint - log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s", oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint) loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint @@ -479,7 +519,7 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { if err != nil { return nil, err } - log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s", u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain) loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host @@ -489,10 +529,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { } if loadedConfig.PKCEAuthorizationFlow != nil { - log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint) loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint - log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s", oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint) loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint } @@ -501,8 +541,8 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { return loadedConfig, err } -func updateMgmtConfig(path string, config *server.Config) error { - return util.DirectWriteJson(path, config) +func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error { + return util.DirectWriteJson(ctx, path, config) } // OIDCConfigResponse used for parsing OIDC config response @@ -515,7 +555,7 @@ type OIDCConfigResponse struct { } // fetchOIDCConfig fetches OIDC configuration from the IDP -func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) { +func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigResponse, error) { res, err := http.Get(oidcEndpoint) if err != nil { return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err) @@ -524,7 +564,7 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) { defer func() { err := res.Body.Close() if err != nil { - log.Debugf("failed closing response body %v", err) + log.WithContext(ctx).Debugf("failed closing response body %v", err) } }() diff --git a/management/cmd/migration_down.go b/management/cmd/migration_down.go deleted file mode 100644 index 81ef93a6c..000000000 --- a/management/cmd/migration_down.go +++ /dev/null @@ -1,67 +0,0 @@ -package cmd - -import ( - "errors" - "flag" - "fmt" - "os" - "path" - - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/util" -) - -var shortDown = "Rollback SQLite store to JSON file store. Please make a backup of the SQLite file before running this command." - -var downCmd = &cobra.Command{ - Use: "downgrade [--datadir directory] [--log-file console]", - Aliases: []string{"down"}, - Short: shortDown, - Long: shortDown + - "\n\n" + - "This command reads the content of {datadir}/store.db and migrates it to {datadir}/store.json that can be used by File store driver.", - RunE: func(cmd *cobra.Command, args []string) error { - flag.Parse() - err := util.InitLog(logLevel, logFile) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } - - sqliteStorePath := path.Join(mgmtDataDir, "store.db") - if _, err := os.Stat(sqliteStorePath); errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("%s doesn't exist, couldn't continue the operation", sqliteStorePath) - } - - fileStorePath := path.Join(mgmtDataDir, "store.json") - if _, err := os.Stat(fileStorePath); err == nil { - return fmt.Errorf("%s already exists, couldn't continue the operation", fileStorePath) - } - - sqlStore, err := server.NewSqliteStore(mgmtDataDir, nil) - if err != nil { - return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err) - } - - sqliteStoreAccounts := len(sqlStore.GetAllAccounts()) - log.Infof("%d account will be migrated from sqlite store %s to file store %s", - sqliteStoreAccounts, sqliteStorePath, fileStorePath) - - store, err := server.NewFilestoreFromSqliteStore(sqlStore, mgmtDataDir, nil) - if err != nil { - return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err) - } - - fsStoreAccounts := len(store.GetAllAccounts()) - if fsStoreAccounts != sqliteStoreAccounts { - return fmt.Errorf("failed to migrate accounts from sqlite to file[]. Expected accounts: %d, got: %d", - sqliteStoreAccounts, fsStoreAccounts) - } - - log.Info("Migration finished successfully") - - return nil - }, -} diff --git a/management/cmd/migration_up.go b/management/cmd/migration_up.go index 5c7505cfc..7aa11f0c9 100644 --- a/management/cmd/migration_up.go +++ b/management/cmd/migration_up.go @@ -1,16 +1,16 @@ package cmd import ( - "errors" + "context" "flag" "fmt" - "os" - "path" - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/util" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/util" ) var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command." @@ -29,37 +29,13 @@ var upCmd = &cobra.Command{ return fmt.Errorf("failed initializing log %v", err) } - fileStorePath := path.Join(mgmtDataDir, "store.json") - if _, err := os.Stat(fileStorePath); errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("%s doesn't exist, couldn't continue the operation", fileStorePath) + //nolint + ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource) + + if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { + return err } - - sqlStorePath := path.Join(mgmtDataDir, "store.db") - if _, err := os.Stat(sqlStorePath); err == nil { - return fmt.Errorf("%s already exists, couldn't continue the operation", sqlStorePath) - } - - fstore, err := server.NewFileStore(mgmtDataDir, nil) - if err != nil { - return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err) - } - - fsStoreAccounts := len(fstore.GetAllAccounts()) - log.Infof("%d account will be migrated from file store %s to sqlite store %s", - fsStoreAccounts, fileStorePath, sqlStorePath) - - store, err := server.NewSqliteStoreFromFileStore(fstore, mgmtDataDir, nil) - if err != nil { - return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err) - } - - sqliteStoreAccounts := len(store.GetAllAccounts()) - if fsStoreAccounts != sqliteStoreAccounts { - return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d", - fsStoreAccounts, sqliteStoreAccounts) - } - - log.Info("Migration finished successfully") + log.WithContext(ctx).Info("Migration finished successfully") return nil }, diff --git a/management/cmd/root.go b/management/cmd/root.go index f5c533969..32c044a56 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -75,7 +75,6 @@ func init() { migrationCmd.MarkFlagRequired("datadir") //nolint migrationCmd.AddCommand(upCmd) - migrationCmd.AddCommand(downCmd) rootCmd.AddCommand(migrationCmd) } diff --git a/management/domain/domain.go b/management/domain/domain.go new file mode 100644 index 000000000..e7e6b050a --- /dev/null +++ b/management/domain/domain.go @@ -0,0 +1,34 @@ +package domain + +import ( + "golang.org/x/net/idna" +) + +type Domain string + +// String converts the Domain to a non-punycode string. +func (d Domain) String() (string, error) { + unicode, err := idna.ToUnicode(string(d)) + if err != nil { + return "", err + } + return unicode, nil +} + +// SafeString converts the Domain to a non-punycode string, falling back to the original string if conversion fails. +func (d Domain) SafeString() string { + str, err := d.String() + if err != nil { + str = string(d) + } + return str +} + +// FromString creates a Domain from a string, converting it to punycode. +func FromString(s string) (Domain, error) { + ascii, err := idna.ToASCII(s) + if err != nil { + return "", err + } + return Domain(ascii), nil +} diff --git a/management/domain/list.go b/management/domain/list.go new file mode 100644 index 000000000..413a23442 --- /dev/null +++ b/management/domain/list.go @@ -0,0 +1,83 @@ +package domain + +import "strings" + +type List []Domain + +// ToStringList converts a List to a slice of string. +func (d List) ToStringList() ([]string, error) { + var list []string + for _, domain := range d { + s, err := domain.String() + if err != nil { + return nil, err + } + list = append(list, s) + } + return list, nil +} + +// ToPunycodeList converts the List to a slice of Punycode-encoded domain strings. +func (d List) ToPunycodeList() []string { + var list []string + for _, domain := range d { + list = append(list, string(domain)) + } + return list +} + +// ToSafeStringList converts the List to a slice of non-punycode strings. +// If a domain cannot be converted, the original string is used. +func (d List) ToSafeStringList() []string { + var list []string + for _, domain := range d { + list = append(list, domain.SafeString()) + } + return list +} + +// String converts List to a comma-separated string. +func (d List) String() (string, error) { + list, err := d.ToStringList() + if err != nil { + return "", err + } + return strings.Join(list, ", "), nil +} + +// SafeString converts List to a comma-separated non-punycode string. +// If a domain cannot be converted, the original string is used. +func (d List) SafeString() string { + str, err := d.String() + if err != nil { + return strings.Join(d.ToPunycodeList(), ", ") + } + return str +} + +// PunycodeString converts the List to a comma-separated string of Punycode-encoded domains. +func (d List) PunycodeString() string { + return strings.Join(d.ToPunycodeList(), ", ") +} + +// FromStringList creates a DomainList from a slice of string. +func FromStringList(s []string) (List, error) { + var dl List + for _, domain := range s { + d, err := FromString(domain) + if err != nil { + return nil, err + } + dl = append(dl, d) + } + return dl, nil +} + +// FromPunycodeList creates a List from a slice of Punycode-encoded domain strings. +func FromPunycodeList(s []string) List { + var dl List + for _, domain := range s { + dl = append(dl, Domain(domain)) + } + return dl +} diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index bd097f9b3..48f048c4c 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -73,7 +73,7 @@ func (x HostConfig_Protocol) Number() protoreflect.EnumNumber { // Deprecated: Use HostConfig_Protocol.Descriptor instead. func (HostConfig_Protocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{11, 0} + return file_management_proto_rawDescGZIP(), []int{13, 0} } type DeviceAuthorizationFlowProvider int32 @@ -116,7 +116,7 @@ func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { // Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{19, 0} + return file_management_proto_rawDescGZIP(), []int{21, 0} } type FirewallRuleDirection int32 @@ -162,7 +162,7 @@ func (x FirewallRuleDirection) Number() protoreflect.EnumNumber { // Deprecated: Use FirewallRuleDirection.Descriptor instead. func (FirewallRuleDirection) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29, 0} + return file_management_proto_rawDescGZIP(), []int{31, 0} } type FirewallRuleAction int32 @@ -208,7 +208,7 @@ func (x FirewallRuleAction) Number() protoreflect.EnumNumber { // Deprecated: Use FirewallRuleAction.Descriptor instead. func (FirewallRuleAction) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29, 1} + return file_management_proto_rawDescGZIP(), []int{31, 1} } type FirewallRuleProtocol int32 @@ -263,7 +263,7 @@ func (x FirewallRuleProtocol) Number() protoreflect.EnumNumber { // Deprecated: Use FirewallRuleProtocol.Descriptor instead. func (FirewallRuleProtocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29, 2} + return file_management_proto_rawDescGZIP(), []int{31, 2} } type EncryptedMessage struct { @@ -336,6 +336,9 @@ type SyncRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + + // Meta data of the peer + Meta *PeerSystemMeta `protobuf:"bytes,1,opt,name=meta,proto3" json:"meta,omitempty"` } func (x *SyncRequest) Reset() { @@ -370,6 +373,13 @@ func (*SyncRequest) Descriptor() ([]byte, []int) { return file_management_proto_rawDescGZIP(), []int{1} } +func (x *SyncRequest) GetMeta() *PeerSystemMeta { + if x != nil { + return x.Meta + } + return nil +} + // SyncResponse represents a state that should be applied to the local peer (e.g. Wiretrustee servers config as well as local peer and remote peers configs) type SyncResponse struct { state protoimpl.MessageState @@ -386,6 +396,8 @@ type SyncResponse struct { // Deprecated. Use NetworkMap.remotePeersIsEmpty RemotePeersIsEmpty bool `protobuf:"varint,4,opt,name=remotePeersIsEmpty,proto3" json:"remotePeersIsEmpty,omitempty"` NetworkMap *NetworkMap `protobuf:"bytes,5,opt,name=NetworkMap,proto3" json:"NetworkMap,omitempty"` + // Posture checks to be evaluated by client + Checks []*Checks `protobuf:"bytes,6,rep,name=Checks,proto3" json:"Checks,omitempty"` } func (x *SyncResponse) Reset() { @@ -455,6 +467,61 @@ func (x *SyncResponse) GetNetworkMap() *NetworkMap { return nil } +func (x *SyncResponse) GetChecks() []*Checks { + if x != nil { + return x.Checks + } + return nil +} + +type SyncMetaRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Meta data of the peer + Meta *PeerSystemMeta `protobuf:"bytes,1,opt,name=meta,proto3" json:"meta,omitempty"` +} + +func (x *SyncMetaRequest) Reset() { + *x = SyncMetaRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMetaRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMetaRequest) ProtoMessage() {} + +func (x *SyncMetaRequest) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMetaRequest.ProtoReflect.Descriptor instead. +func (*SyncMetaRequest) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{3} +} + +func (x *SyncMetaRequest) GetMeta() *PeerSystemMeta { + if x != nil { + return x.Meta + } + return nil +} + type LoginRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -473,7 +540,7 @@ type LoginRequest struct { func (x *LoginRequest) Reset() { *x = LoginRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[3] + mi := &file_management_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -486,7 +553,7 @@ func (x *LoginRequest) String() string { func (*LoginRequest) ProtoMessage() {} func (x *LoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[3] + mi := &file_management_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -499,7 +566,7 @@ func (x *LoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginRequest.ProtoReflect.Descriptor instead. func (*LoginRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{3} + return file_management_proto_rawDescGZIP(), []int{4} } func (x *LoginRequest) GetSetupKey() string { @@ -546,7 +613,7 @@ type PeerKeys struct { func (x *PeerKeys) Reset() { *x = PeerKeys{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[4] + mi := &file_management_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -559,7 +626,7 @@ func (x *PeerKeys) String() string { func (*PeerKeys) ProtoMessage() {} func (x *PeerKeys) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[4] + mi := &file_management_proto_msgTypes[5] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -572,7 +639,7 @@ func (x *PeerKeys) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerKeys.ProtoReflect.Descriptor instead. func (*PeerKeys) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{4} + return file_management_proto_rawDescGZIP(), []int{5} } func (x *PeerKeys) GetSshPubKey() []byte { @@ -604,7 +671,7 @@ type Environment struct { func (x *Environment) Reset() { *x = Environment{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[5] + mi := &file_management_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -617,7 +684,7 @@ func (x *Environment) String() string { func (*Environment) ProtoMessage() {} func (x *Environment) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[5] + mi := &file_management_proto_msgTypes[6] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -630,7 +697,7 @@ func (x *Environment) ProtoReflect() protoreflect.Message { // Deprecated: Use Environment.ProtoReflect.Descriptor instead. func (*Environment) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{5} + return file_management_proto_rawDescGZIP(), []int{6} } func (x *Environment) GetCloud() string { @@ -647,6 +714,73 @@ func (x *Environment) GetPlatform() string { return "" } +// File represents a file on the system. +type File struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // path is the path to the file. + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + // exist indicate whether the file exists. + Exist bool `protobuf:"varint,2,opt,name=exist,proto3" json:"exist,omitempty"` + // processIsRunning indicates whether the file is a running process or not. + ProcessIsRunning bool `protobuf:"varint,3,opt,name=processIsRunning,proto3" json:"processIsRunning,omitempty"` +} + +func (x *File) Reset() { + *x = File{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *File) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*File) ProtoMessage() {} + +func (x *File) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use File.ProtoReflect.Descriptor instead. +func (*File) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{7} +} + +func (x *File) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +func (x *File) GetExist() bool { + if x != nil { + return x.Exist + } + return false +} + +func (x *File) GetProcessIsRunning() bool { + if x != nil { + return x.ProcessIsRunning + } + return false +} + // PeerSystemMeta is machine meta data like OS and version. type PeerSystemMeta struct { state protoimpl.MessageState @@ -668,12 +802,13 @@ type PeerSystemMeta struct { SysProductName string `protobuf:"bytes,13,opt,name=sysProductName,proto3" json:"sysProductName,omitempty"` SysManufacturer string `protobuf:"bytes,14,opt,name=sysManufacturer,proto3" json:"sysManufacturer,omitempty"` Environment *Environment `protobuf:"bytes,15,opt,name=environment,proto3" json:"environment,omitempty"` + Files []*File `protobuf:"bytes,16,rep,name=files,proto3" json:"files,omitempty"` } func (x *PeerSystemMeta) Reset() { *x = PeerSystemMeta{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[6] + mi := &file_management_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -686,7 +821,7 @@ func (x *PeerSystemMeta) String() string { func (*PeerSystemMeta) ProtoMessage() {} func (x *PeerSystemMeta) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[6] + mi := &file_management_proto_msgTypes[8] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -699,7 +834,7 @@ func (x *PeerSystemMeta) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerSystemMeta.ProtoReflect.Descriptor instead. func (*PeerSystemMeta) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{6} + return file_management_proto_rawDescGZIP(), []int{8} } func (x *PeerSystemMeta) GetHostname() string { @@ -807,6 +942,13 @@ func (x *PeerSystemMeta) GetEnvironment() *Environment { return nil } +func (x *PeerSystemMeta) GetFiles() []*File { + if x != nil { + return x.Files + } + return nil +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -816,12 +958,14 @@ type LoginResponse struct { WiretrusteeConfig *WiretrusteeConfig `protobuf:"bytes,1,opt,name=wiretrusteeConfig,proto3" json:"wiretrusteeConfig,omitempty"` // Peer local config PeerConfig *PeerConfig `protobuf:"bytes,2,opt,name=peerConfig,proto3" json:"peerConfig,omitempty"` + // Posture checks to be evaluated by client + Checks []*Checks `protobuf:"bytes,3,rep,name=Checks,proto3" json:"Checks,omitempty"` } func (x *LoginResponse) Reset() { *x = LoginResponse{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[7] + mi := &file_management_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -834,7 +978,7 @@ func (x *LoginResponse) String() string { func (*LoginResponse) ProtoMessage() {} func (x *LoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[7] + mi := &file_management_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -847,7 +991,7 @@ func (x *LoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginResponse.ProtoReflect.Descriptor instead. func (*LoginResponse) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{7} + return file_management_proto_rawDescGZIP(), []int{9} } func (x *LoginResponse) GetWiretrusteeConfig() *WiretrusteeConfig { @@ -864,6 +1008,13 @@ func (x *LoginResponse) GetPeerConfig() *PeerConfig { return nil } +func (x *LoginResponse) GetChecks() []*Checks { + if x != nil { + return x.Checks + } + return nil +} + type ServerKeyResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -880,7 +1031,7 @@ type ServerKeyResponse struct { func (x *ServerKeyResponse) Reset() { *x = ServerKeyResponse{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[8] + mi := &file_management_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -893,7 +1044,7 @@ func (x *ServerKeyResponse) String() string { func (*ServerKeyResponse) ProtoMessage() {} func (x *ServerKeyResponse) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[8] + mi := &file_management_proto_msgTypes[10] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -906,7 +1057,7 @@ func (x *ServerKeyResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ServerKeyResponse.ProtoReflect.Descriptor instead. func (*ServerKeyResponse) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{8} + return file_management_proto_rawDescGZIP(), []int{10} } func (x *ServerKeyResponse) GetKey() string { @@ -939,7 +1090,7 @@ type Empty struct { func (x *Empty) Reset() { *x = Empty{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[9] + mi := &file_management_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -952,7 +1103,7 @@ func (x *Empty) String() string { func (*Empty) ProtoMessage() {} func (x *Empty) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[9] + mi := &file_management_proto_msgTypes[11] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -965,7 +1116,7 @@ func (x *Empty) ProtoReflect() protoreflect.Message { // Deprecated: Use Empty.ProtoReflect.Descriptor instead. func (*Empty) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{9} + return file_management_proto_rawDescGZIP(), []int{11} } // WiretrusteeConfig is a common configuration of any Wiretrustee peer. It contains STUN, TURN, Signal and Management servers configurations @@ -986,7 +1137,7 @@ type WiretrusteeConfig struct { func (x *WiretrusteeConfig) Reset() { *x = WiretrusteeConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[10] + mi := &file_management_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -999,7 +1150,7 @@ func (x *WiretrusteeConfig) String() string { func (*WiretrusteeConfig) ProtoMessage() {} func (x *WiretrusteeConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[10] + mi := &file_management_proto_msgTypes[12] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1012,7 +1163,7 @@ func (x *WiretrusteeConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use WiretrusteeConfig.ProtoReflect.Descriptor instead. func (*WiretrusteeConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{10} + return file_management_proto_rawDescGZIP(), []int{12} } func (x *WiretrusteeConfig) GetStuns() []*HostConfig { @@ -1057,7 +1208,7 @@ type HostConfig struct { func (x *HostConfig) Reset() { *x = HostConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[11] + mi := &file_management_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1070,7 +1221,7 @@ func (x *HostConfig) String() string { func (*HostConfig) ProtoMessage() {} func (x *HostConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[11] + mi := &file_management_proto_msgTypes[13] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1083,7 +1234,7 @@ func (x *HostConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use HostConfig.ProtoReflect.Descriptor instead. func (*HostConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{11} + return file_management_proto_rawDescGZIP(), []int{13} } func (x *HostConfig) GetUri() string { @@ -1113,7 +1264,7 @@ type RelayConfig struct { func (x *RelayConfig) Reset() { *x = RelayConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[12] + mi := &file_management_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1126,7 +1277,7 @@ func (x *RelayConfig) String() string { func (*RelayConfig) ProtoMessage() {} func (x *RelayConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[12] + mi := &file_management_proto_msgTypes[14] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1139,7 +1290,7 @@ func (x *RelayConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use RelayConfig.ProtoReflect.Descriptor instead. func (*RelayConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{12} + return file_management_proto_rawDescGZIP(), []int{14} } func (x *RelayConfig) GetUrls() []string { @@ -1178,7 +1329,7 @@ type ProtectedHostConfig struct { func (x *ProtectedHostConfig) Reset() { *x = ProtectedHostConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[13] + mi := &file_management_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1191,7 +1342,7 @@ func (x *ProtectedHostConfig) String() string { func (*ProtectedHostConfig) ProtoMessage() {} func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[13] + mi := &file_management_proto_msgTypes[15] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1204,7 +1355,7 @@ func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProtectedHostConfig.ProtoReflect.Descriptor instead. func (*ProtectedHostConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{13} + return file_management_proto_rawDescGZIP(), []int{15} } func (x *ProtectedHostConfig) GetHostConfig() *HostConfig { @@ -1248,7 +1399,7 @@ type PeerConfig struct { func (x *PeerConfig) Reset() { *x = PeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[14] + mi := &file_management_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1261,7 +1412,7 @@ func (x *PeerConfig) String() string { func (*PeerConfig) ProtoMessage() {} func (x *PeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[14] + mi := &file_management_proto_msgTypes[16] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1274,7 +1425,7 @@ func (x *PeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerConfig.ProtoReflect.Descriptor instead. func (*PeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{14} + return file_management_proto_rawDescGZIP(), []int{16} } func (x *PeerConfig) GetAddress() string { @@ -1336,7 +1487,7 @@ type NetworkMap struct { func (x *NetworkMap) Reset() { *x = NetworkMap{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[15] + mi := &file_management_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1349,7 +1500,7 @@ func (x *NetworkMap) String() string { func (*NetworkMap) ProtoMessage() {} func (x *NetworkMap) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[15] + mi := &file_management_proto_msgTypes[17] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1362,7 +1513,7 @@ func (x *NetworkMap) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkMap.ProtoReflect.Descriptor instead. func (*NetworkMap) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{15} + return file_management_proto_rawDescGZIP(), []int{17} } func (x *NetworkMap) GetSerial() uint64 { @@ -1448,7 +1599,7 @@ type RemotePeerConfig struct { func (x *RemotePeerConfig) Reset() { *x = RemotePeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[16] + mi := &file_management_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1461,7 +1612,7 @@ func (x *RemotePeerConfig) String() string { func (*RemotePeerConfig) ProtoMessage() {} func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[16] + mi := &file_management_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1474,7 +1625,7 @@ func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use RemotePeerConfig.ProtoReflect.Descriptor instead. func (*RemotePeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{16} + return file_management_proto_rawDescGZIP(), []int{18} } func (x *RemotePeerConfig) GetWgPubKey() string { @@ -1521,7 +1672,7 @@ type SSHConfig struct { func (x *SSHConfig) Reset() { *x = SSHConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[17] + mi := &file_management_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1534,7 +1685,7 @@ func (x *SSHConfig) String() string { func (*SSHConfig) ProtoMessage() {} func (x *SSHConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[17] + mi := &file_management_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1547,7 +1698,7 @@ func (x *SSHConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHConfig.ProtoReflect.Descriptor instead. func (*SSHConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{17} + return file_management_proto_rawDescGZIP(), []int{19} } func (x *SSHConfig) GetSshEnabled() bool { @@ -1574,7 +1725,7 @@ type DeviceAuthorizationFlowRequest struct { func (x *DeviceAuthorizationFlowRequest) Reset() { *x = DeviceAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[18] + mi := &file_management_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1587,7 +1738,7 @@ func (x *DeviceAuthorizationFlowRequest) String() string { func (*DeviceAuthorizationFlowRequest) ProtoMessage() {} func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[18] + mi := &file_management_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1600,7 +1751,7 @@ func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{18} + return file_management_proto_rawDescGZIP(), []int{20} } // DeviceAuthorizationFlow represents Device Authorization Flow information @@ -1619,7 +1770,7 @@ type DeviceAuthorizationFlow struct { func (x *DeviceAuthorizationFlow) Reset() { *x = DeviceAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[19] + mi := &file_management_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1632,7 +1783,7 @@ func (x *DeviceAuthorizationFlow) String() string { func (*DeviceAuthorizationFlow) ProtoMessage() {} func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[19] + mi := &file_management_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1645,7 +1796,7 @@ func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlow.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{19} + return file_management_proto_rawDescGZIP(), []int{21} } func (x *DeviceAuthorizationFlow) GetProvider() DeviceAuthorizationFlowProvider { @@ -1672,7 +1823,7 @@ type PKCEAuthorizationFlowRequest struct { func (x *PKCEAuthorizationFlowRequest) Reset() { *x = PKCEAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1685,7 +1836,7 @@ func (x *PKCEAuthorizationFlowRequest) String() string { func (*PKCEAuthorizationFlowRequest) ProtoMessage() {} func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1698,7 +1849,7 @@ func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{20} + return file_management_proto_rawDescGZIP(), []int{22} } // PKCEAuthorizationFlow represents Authorization Code Flow information @@ -1715,7 +1866,7 @@ type PKCEAuthorizationFlow struct { func (x *PKCEAuthorizationFlow) Reset() { *x = PKCEAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1728,7 +1879,7 @@ func (x *PKCEAuthorizationFlow) String() string { func (*PKCEAuthorizationFlow) ProtoMessage() {} func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[23] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1741,7 +1892,7 @@ func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlow.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{21} + return file_management_proto_rawDescGZIP(), []int{23} } func (x *PKCEAuthorizationFlow) GetProviderConfig() *ProviderConfig { @@ -1783,7 +1934,7 @@ type ProviderConfig struct { func (x *ProviderConfig) Reset() { *x = ProviderConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1796,7 +1947,7 @@ func (x *ProviderConfig) String() string { func (*ProviderConfig) ProtoMessage() {} func (x *ProviderConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1809,7 +1960,7 @@ func (x *ProviderConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProviderConfig.ProtoReflect.Descriptor instead. func (*ProviderConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{22} + return file_management_proto_rawDescGZIP(), []int{24} } func (x *ProviderConfig) GetClientID() string { @@ -1888,19 +2039,21 @@ type Route struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` - Network string `protobuf:"bytes,2,opt,name=Network,proto3" json:"Network,omitempty"` - NetworkType int64 `protobuf:"varint,3,opt,name=NetworkType,proto3" json:"NetworkType,omitempty"` - Peer string `protobuf:"bytes,4,opt,name=Peer,proto3" json:"Peer,omitempty"` - Metric int64 `protobuf:"varint,5,opt,name=Metric,proto3" json:"Metric,omitempty"` - Masquerade bool `protobuf:"varint,6,opt,name=Masquerade,proto3" json:"Masquerade,omitempty"` - NetID string `protobuf:"bytes,7,opt,name=NetID,proto3" json:"NetID,omitempty"` + ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` + Network string `protobuf:"bytes,2,opt,name=Network,proto3" json:"Network,omitempty"` + NetworkType int64 `protobuf:"varint,3,opt,name=NetworkType,proto3" json:"NetworkType,omitempty"` + Peer string `protobuf:"bytes,4,opt,name=Peer,proto3" json:"Peer,omitempty"` + Metric int64 `protobuf:"varint,5,opt,name=Metric,proto3" json:"Metric,omitempty"` + Masquerade bool `protobuf:"varint,6,opt,name=Masquerade,proto3" json:"Masquerade,omitempty"` + NetID string `protobuf:"bytes,7,opt,name=NetID,proto3" json:"NetID,omitempty"` + Domains []string `protobuf:"bytes,8,rep,name=Domains,proto3" json:"Domains,omitempty"` + KeepRoute bool `protobuf:"varint,9,opt,name=keepRoute,proto3" json:"keepRoute,omitempty"` } func (x *Route) Reset() { *x = Route{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1913,7 +2066,7 @@ func (x *Route) String() string { func (*Route) ProtoMessage() {} func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1926,7 +2079,7 @@ func (x *Route) ProtoReflect() protoreflect.Message { // Deprecated: Use Route.ProtoReflect.Descriptor instead. func (*Route) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23} + return file_management_proto_rawDescGZIP(), []int{25} } func (x *Route) GetID() string { @@ -1978,6 +2131,20 @@ func (x *Route) GetNetID() string { return "" } +func (x *Route) GetDomains() []string { + if x != nil { + return x.Domains + } + return nil +} + +func (x *Route) GetKeepRoute() bool { + if x != nil { + return x.KeepRoute + } + return false +} + // DNSConfig represents a dns.Update type DNSConfig struct { state protoimpl.MessageState @@ -1992,7 +2159,7 @@ type DNSConfig struct { func (x *DNSConfig) Reset() { *x = DNSConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2005,7 +2172,7 @@ func (x *DNSConfig) String() string { func (*DNSConfig) ProtoMessage() {} func (x *DNSConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2018,7 +2185,7 @@ func (x *DNSConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead. func (*DNSConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24} + return file_management_proto_rawDescGZIP(), []int{26} } func (x *DNSConfig) GetServiceEnable() bool { @@ -2055,7 +2222,7 @@ type CustomZone struct { func (x *CustomZone) Reset() { *x = CustomZone{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2068,7 +2235,7 @@ func (x *CustomZone) String() string { func (*CustomZone) ProtoMessage() {} func (x *CustomZone) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2081,7 +2248,7 @@ func (x *CustomZone) ProtoReflect() protoreflect.Message { // Deprecated: Use CustomZone.ProtoReflect.Descriptor instead. func (*CustomZone) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{25} + return file_management_proto_rawDescGZIP(), []int{27} } func (x *CustomZone) GetDomain() string { @@ -2114,7 +2281,7 @@ type SimpleRecord struct { func (x *SimpleRecord) Reset() { *x = SimpleRecord{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2127,7 +2294,7 @@ func (x *SimpleRecord) String() string { func (*SimpleRecord) ProtoMessage() {} func (x *SimpleRecord) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2140,7 +2307,7 @@ func (x *SimpleRecord) ProtoReflect() protoreflect.Message { // Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead. func (*SimpleRecord) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{26} + return file_management_proto_rawDescGZIP(), []int{28} } func (x *SimpleRecord) GetName() string { @@ -2193,7 +2360,7 @@ type NameServerGroup struct { func (x *NameServerGroup) Reset() { *x = NameServerGroup{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2206,7 +2373,7 @@ func (x *NameServerGroup) String() string { func (*NameServerGroup) ProtoMessage() {} func (x *NameServerGroup) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2219,7 +2386,7 @@ func (x *NameServerGroup) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead. func (*NameServerGroup) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{27} + return file_management_proto_rawDescGZIP(), []int{29} } func (x *NameServerGroup) GetNameServers() []*NameServer { @@ -2264,7 +2431,7 @@ type NameServer struct { func (x *NameServer) Reset() { *x = NameServer{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2277,7 +2444,7 @@ func (x *NameServer) String() string { func (*NameServer) ProtoMessage() {} func (x *NameServer) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2290,7 +2457,7 @@ func (x *NameServer) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServer.ProtoReflect.Descriptor instead. func (*NameServer) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{28} + return file_management_proto_rawDescGZIP(), []int{30} } func (x *NameServer) GetIP() string { @@ -2330,7 +2497,7 @@ type FirewallRule struct { func (x *FirewallRule) Reset() { *x = FirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2343,7 +2510,7 @@ func (x *FirewallRule) String() string { func (*FirewallRule) ProtoMessage() {} func (x *FirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2356,7 +2523,7 @@ func (x *FirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead. func (*FirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29} + return file_management_proto_rawDescGZIP(), []int{31} } func (x *FirewallRule) GetPeerIP() string { @@ -2406,7 +2573,7 @@ type NetworkAddress struct { func (x *NetworkAddress) Reset() { *x = NetworkAddress{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2419,7 +2586,7 @@ func (x *NetworkAddress) String() string { func (*NetworkAddress) ProtoMessage() {} func (x *NetworkAddress) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[32] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2432,7 +2599,7 @@ func (x *NetworkAddress) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkAddress.ProtoReflect.Descriptor instead. func (*NetworkAddress) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{30} + return file_management_proto_rawDescGZIP(), []int{32} } func (x *NetworkAddress) GetNetIP() string { @@ -2449,6 +2616,53 @@ func (x *NetworkAddress) GetMac() string { return "" } +type Checks struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Files []string `protobuf:"bytes,1,rep,name=Files,proto3" json:"Files,omitempty"` +} + +func (x *Checks) Reset() { + *x = Checks{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Checks) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Checks) ProtoMessage() {} + +func (x *Checks) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[33] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Checks.ProtoReflect.Descriptor instead. +func (*Checks) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{33} +} + +func (x *Checks) GetFiles() []string { + if x != nil { + return x.Files + } + return nil +} + var File_management_proto protoreflect.FileDescriptor var file_management_proto_rawDesc = []byte{ @@ -2461,8 +2675,11 @@ var file_management_proto_rawDesc = []byte{ 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, - 0x0b, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbb, 0x02, 0x0a, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x3d, 0x0a, + 0x0b, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a, 0x04, + 0x6d, 0x65, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, + 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x52, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x22, 0xe7, 0x02, 0x0a, 0x0c, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x11, 0x77, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, @@ -2482,324 +2699,351 @@ var file_management_proto_rawDesc = []byte{ 0x74, 0x79, 0x12, 0x36, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x52, 0x0a, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x22, 0xa8, 0x01, 0x0a, 0x0c, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, - 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, - 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x2e, 0x0a, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, - 0x61, 0x52, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x12, 0x1a, 0x0a, 0x08, 0x6a, 0x77, 0x74, 0x54, 0x6f, - 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6a, 0x77, 0x74, 0x54, 0x6f, - 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x65, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x73, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x73, 0x52, 0x08, 0x70, 0x65, 0x65, - 0x72, 0x4b, 0x65, 0x79, 0x73, 0x22, 0x44, 0x0a, 0x08, 0x50, 0x65, 0x65, 0x72, 0x4b, 0x65, 0x79, - 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, - 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x3f, 0x0a, 0x0b, 0x45, - 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x63, 0x6c, - 0x6f, 0x75, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x63, 0x6c, 0x6f, 0x75, 0x64, - 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x22, 0xa9, 0x04, 0x0a, - 0x0e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x12, - 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, - 0x6f, 0x4f, 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x12, - 0x16, 0x0a, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, - 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, - 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x2e, 0x0a, 0x12, 0x77, 0x69, 0x72, 0x65, 0x74, - 0x72, 0x75, 0x73, 0x74, 0x65, 0x65, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, 0x65, - 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, - 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, - 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, - 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, - 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, - 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, - 0x6d, 0x62, 0x65, 0x72, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, - 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, - 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, - 0x61, 0x6d, 0x65, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, - 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, - 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, - 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, - 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x94, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x11, 0x77, 0x69, - 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, 0x65, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x52, 0x11, 0x77, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, - 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, - 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, - 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, - 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, - 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x22, 0xd7, 0x01, 0x0a, 0x11, 0x57, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, - 0x74, 0x65, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, - 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a, 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x12, 0x2e, - 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x2d, - 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x22, 0x98, 0x01, - 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, - 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, - 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, - 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, - 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, - 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, - 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, - 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, - 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, - 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, - 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, 0x6f, 0x74, 0x65, - 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, - 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, 0x6f, 0x73, 0x74, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, - 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, - 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x81, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, - 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x6e, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x2a, 0x0a, 0x06, 0x43, 0x68, + 0x65, 0x63, 0x6b, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x52, 0x06, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x41, 0x0a, 0x0f, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, + 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2e, 0x0a, 0x04, 0x6d, 0x65, 0x74, + 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, + 0x65, 0x74, 0x61, 0x52, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x22, 0xa8, 0x01, 0x0a, 0x0c, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, + 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, + 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x2e, 0x0a, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, + 0x52, 0x04, 0x6d, 0x65, 0x74, 0x61, 0x12, 0x1a, 0x0a, 0x08, 0x6a, 0x77, 0x74, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6a, 0x77, 0x74, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x65, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x73, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x73, 0x52, 0x08, 0x70, 0x65, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x73, 0x22, 0x44, 0x0a, 0x08, 0x50, 0x65, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x73, + 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1a, + 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x3f, 0x0a, 0x0b, 0x45, 0x6e, + 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x63, 0x6c, 0x6f, + 0x75, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x12, + 0x1a, 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x22, 0x5c, 0x0a, 0x04, 0x46, + 0x69, 0x6c, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x78, 0x69, 0x73, 0x74, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x78, 0x69, 0x73, 0x74, 0x12, 0x2a, 0x0a, + 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, + 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, + 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xd1, 0x04, 0x0a, 0x0e, 0x50, 0x65, + 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1a, 0x0a, 0x08, + 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x6f, 0x4f, 0x53, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x12, 0x16, 0x0a, 0x06, + 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6b, 0x65, + 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, + 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, + 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x02, 0x4f, 0x53, 0x12, 0x2e, 0x0a, 0x12, 0x77, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, + 0x74, 0x65, 0x65, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, 0x65, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, + 0x69, 0x6f, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, + 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, + 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, + 0x72, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, + 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, + 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, + 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, + 0x72, 0x65, 0x72, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, + 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, + 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, + 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, + 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x22, 0xc0, 0x01, + 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x4b, 0x0a, 0x11, 0x77, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, 0x65, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, 0x73, + 0x74, 0x65, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x11, 0x77, 0x69, 0x72, 0x65, 0x74, + 0x72, 0x75, 0x73, 0x74, 0x65, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, + 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2a, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, + 0x22, 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, + 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, + 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, + 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0xd7, 0x01, 0x0a, 0x11, 0x57, 0x69, 0x72, 0x65, 0x74, 0x72, 0x75, + 0x73, 0x74, 0x65, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, + 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a, 0x05, 0x74, 0x75, 0x72, 0x6e, + 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, + 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x12, + 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x12, + 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6c, 0x61, + 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x22, 0x98, + 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, + 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, + 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, + 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, + 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, + 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, + 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, + 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, + 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, + 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, 0x6f, 0x74, + 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, 0x6f, 0x73, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, + 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, + 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x81, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, + 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, + 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, + 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xe2, 0x03, 0x0a, 0x0a, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, + 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, + 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, + 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, + 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, + 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, + 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, + 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, + 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, + 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, + 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, + 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, + 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, + 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, + 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xe2, 0x03, 0x0a, 0x0a, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, - 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, - 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, - 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, - 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, - 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, - 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, - 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, - 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, - 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, - 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, - 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, - 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, - 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, - 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, - 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, - 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, - 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, - 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, - 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, - 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, - 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, - 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, - 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, - 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, - 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, - 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, - 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, - 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, - 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, - 0x4c, 0x73, 0x22, 0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, - 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, - 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, - 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, - 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, - 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, - 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, - 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, - 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, - 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, - 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, - 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, - 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, - 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, - 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, - 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, - 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, - 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, - 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, - 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, - 0x74, 0x22, 0xf0, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, - 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x40, 0x0a, 0x09, 0x44, 0x69, - 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x37, 0x0a, 0x06, - 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x3d, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, - 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x22, 0x1e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, - 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x22, 0x3c, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, - 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, - 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, - 0x4d, 0x50, 0x10, 0x04, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, - 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x32, 0xd1, - 0x03, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, - 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, + 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, + 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, + 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, + 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, + 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, + 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, + 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, + 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, + 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, + 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, + 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, + 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, + 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, + 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, + 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, + 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, + 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, + 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, + 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, + 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, + 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, + 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, + 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, + 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, + 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, + 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, + 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, + 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, + 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, + 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, + 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, + 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, + 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, + 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, + 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, + 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, + 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xf0, 0x02, 0x0a, 0x0c, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, + 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, + 0x72, 0x49, 0x50, 0x12, 0x40, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x37, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, + 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x3d, + 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, + 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, + 0x74, 0x22, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, + 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x22, + 0x1e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, + 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x22, + 0x3c, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, + 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, + 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, + 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x22, 0x38, 0x0a, + 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, + 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, + 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, + 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, + 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, + 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, - 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, - 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, - 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, + 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, + 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -2815,7 +3059,7 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 31) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 34) var file_management_proto_goTypes = []interface{}{ (HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol (DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider @@ -2825,89 +3069,99 @@ var file_management_proto_goTypes = []interface{}{ (*EncryptedMessage)(nil), // 5: management.EncryptedMessage (*SyncRequest)(nil), // 6: management.SyncRequest (*SyncResponse)(nil), // 7: management.SyncResponse - (*LoginRequest)(nil), // 8: management.LoginRequest - (*PeerKeys)(nil), // 9: management.PeerKeys - (*Environment)(nil), // 10: management.Environment - (*PeerSystemMeta)(nil), // 11: management.PeerSystemMeta - (*LoginResponse)(nil), // 12: management.LoginResponse - (*ServerKeyResponse)(nil), // 13: management.ServerKeyResponse - (*Empty)(nil), // 14: management.Empty - (*WiretrusteeConfig)(nil), // 15: management.WiretrusteeConfig - (*HostConfig)(nil), // 16: management.HostConfig - (*RelayConfig)(nil), // 17: management.RelayConfig - (*ProtectedHostConfig)(nil), // 18: management.ProtectedHostConfig - (*PeerConfig)(nil), // 19: management.PeerConfig - (*NetworkMap)(nil), // 20: management.NetworkMap - (*RemotePeerConfig)(nil), // 21: management.RemotePeerConfig - (*SSHConfig)(nil), // 22: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 23: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 24: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 25: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 26: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 27: management.ProviderConfig - (*Route)(nil), // 28: management.Route - (*DNSConfig)(nil), // 29: management.DNSConfig - (*CustomZone)(nil), // 30: management.CustomZone - (*SimpleRecord)(nil), // 31: management.SimpleRecord - (*NameServerGroup)(nil), // 32: management.NameServerGroup - (*NameServer)(nil), // 33: management.NameServer - (*FirewallRule)(nil), // 34: management.FirewallRule - (*NetworkAddress)(nil), // 35: management.NetworkAddress - (*timestamppb.Timestamp)(nil), // 36: google.protobuf.Timestamp + (*SyncMetaRequest)(nil), // 8: management.SyncMetaRequest + (*LoginRequest)(nil), // 9: management.LoginRequest + (*PeerKeys)(nil), // 10: management.PeerKeys + (*Environment)(nil), // 11: management.Environment + (*File)(nil), // 12: management.File + (*PeerSystemMeta)(nil), // 13: management.PeerSystemMeta + (*LoginResponse)(nil), // 14: management.LoginResponse + (*ServerKeyResponse)(nil), // 15: management.ServerKeyResponse + (*Empty)(nil), // 16: management.Empty + (*WiretrusteeConfig)(nil), // 17: management.WiretrusteeConfig + (*HostConfig)(nil), // 18: management.HostConfig + (*RelayConfig)(nil), // 19: management.RelayConfig + (*ProtectedHostConfig)(nil), // 20: management.ProtectedHostConfig + (*PeerConfig)(nil), // 21: management.PeerConfig + (*NetworkMap)(nil), // 22: management.NetworkMap + (*RemotePeerConfig)(nil), // 23: management.RemotePeerConfig + (*SSHConfig)(nil), // 24: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 25: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 26: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 27: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 28: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 29: management.ProviderConfig + (*Route)(nil), // 30: management.Route + (*DNSConfig)(nil), // 31: management.DNSConfig + (*CustomZone)(nil), // 32: management.CustomZone + (*SimpleRecord)(nil), // 33: management.SimpleRecord + (*NameServerGroup)(nil), // 34: management.NameServerGroup + (*NameServer)(nil), // 35: management.NameServer + (*FirewallRule)(nil), // 36: management.FirewallRule + (*NetworkAddress)(nil), // 37: management.NetworkAddress + (*Checks)(nil), // 38: management.Checks + (*timestamppb.Timestamp)(nil), // 39: google.protobuf.Timestamp } var file_management_proto_depIdxs = []int32{ - 15, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig - 19, // 1: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 21, // 2: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 20, // 3: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 11, // 4: management.LoginRequest.meta:type_name -> management.PeerSystemMeta - 9, // 5: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 35, // 6: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress - 10, // 7: management.PeerSystemMeta.environment:type_name -> management.Environment - 15, // 8: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig - 19, // 9: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 36, // 10: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp - 16, // 11: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig - 18, // 12: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig - 16, // 13: management.WiretrusteeConfig.signal:type_name -> management.HostConfig - 17, // 14: management.WiretrusteeConfig.relay:type_name -> management.RelayConfig - 0, // 15: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 16, // 16: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 22, // 17: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 19, // 18: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 21, // 19: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 28, // 20: management.NetworkMap.Routes:type_name -> management.Route - 29, // 21: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 21, // 22: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 34, // 23: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 22, // 24: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 1, // 25: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 27, // 26: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 27, // 27: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 32, // 28: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 30, // 29: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 31, // 30: management.CustomZone.Records:type_name -> management.SimpleRecord - 33, // 31: management.NameServerGroup.NameServers:type_name -> management.NameServer - 2, // 32: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction - 3, // 33: management.FirewallRule.Action:type_name -> management.FirewallRule.action - 4, // 34: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol - 5, // 35: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 36: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 14, // 37: management.ManagementService.GetServerKey:input_type -> management.Empty - 14, // 38: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 39: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 40: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 41: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 42: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 13, // 43: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 14, // 44: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 45: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 46: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 41, // [41:47] is the sub-list for method output_type - 35, // [35:41] is the sub-list for method input_type - 35, // [35:35] is the sub-list for extension type_name - 35, // [35:35] is the sub-list for extension extendee - 0, // [0:35] is the sub-list for field type_name + 13, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta + 17, // 1: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig + 21, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig + 23, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 22, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 38, // 5: management.SyncResponse.Checks:type_name -> management.Checks + 13, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta + 13, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta + 10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys + 37, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment + 12, // 11: management.PeerSystemMeta.files:type_name -> management.File + 17, // 12: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig + 21, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig + 38, // 14: management.LoginResponse.Checks:type_name -> management.Checks + 39, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 18, // 16: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig + 20, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig + 18, // 18: management.WiretrusteeConfig.signal:type_name -> management.HostConfig + 19, // 19: management.WiretrusteeConfig.relay:type_name -> management.RelayConfig + 0, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 18, // 21: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 24, // 22: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 21, // 23: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 23, // 24: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 30, // 25: management.NetworkMap.Routes:type_name -> management.Route + 31, // 26: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 23, // 27: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 36, // 28: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 24, // 29: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 1, // 30: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 29, // 31: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 29, // 32: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 34, // 33: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 32, // 34: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 33, // 35: management.CustomZone.Records:type_name -> management.SimpleRecord + 35, // 36: management.NameServerGroup.NameServers:type_name -> management.NameServer + 2, // 37: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction + 3, // 38: management.FirewallRule.Action:type_name -> management.FirewallRule.action + 4, // 39: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol + 5, // 40: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 41: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 16, // 42: management.ManagementService.GetServerKey:input_type -> management.Empty + 16, // 43: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 44: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 45: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 46: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 47: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 48: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 15, // 49: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 16, // 50: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 51: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 52: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 16, // 53: management.ManagementService.SyncMeta:output_type -> management.Empty + 47, // [47:54] is the sub-list for method output_type + 40, // [40:47] is the sub-list for method input_type + 40, // [40:40] is the sub-list for extension type_name + 40, // [40:40] is the sub-list for extension extendee + 0, // [0:40] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -2953,7 +3207,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*LoginRequest); i { + switch v := v.(*SyncMetaRequest); i { case 0: return &v.state case 1: @@ -2965,7 +3219,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PeerKeys); i { + switch v := v.(*LoginRequest); i { case 0: return &v.state case 1: @@ -2977,7 +3231,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Environment); i { + switch v := v.(*PeerKeys); i { case 0: return &v.state case 1: @@ -2989,7 +3243,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PeerSystemMeta); i { + switch v := v.(*Environment); i { case 0: return &v.state case 1: @@ -3001,7 +3255,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*LoginResponse); i { + switch v := v.(*File); i { case 0: return &v.state case 1: @@ -3013,7 +3267,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ServerKeyResponse); i { + switch v := v.(*PeerSystemMeta); i { case 0: return &v.state case 1: @@ -3025,7 +3279,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Empty); i { + switch v := v.(*LoginResponse); i { case 0: return &v.state case 1: @@ -3037,7 +3291,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*WiretrusteeConfig); i { + switch v := v.(*ServerKeyResponse); i { case 0: return &v.state case 1: @@ -3049,7 +3303,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*HostConfig); i { + switch v := v.(*Empty); i { case 0: return &v.state case 1: @@ -3061,7 +3315,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RelayConfig); i { + switch v := v.(*WiretrusteeConfig); i { case 0: return &v.state case 1: @@ -3073,7 +3327,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProtectedHostConfig); i { + switch v := v.(*HostConfig); i { case 0: return &v.state case 1: @@ -3085,7 +3339,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PeerConfig); i { + switch v := v.(*RelayConfig); i { case 0: return &v.state case 1: @@ -3097,7 +3351,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkMap); i { + switch v := v.(*ProtectedHostConfig); i { case 0: return &v.state case 1: @@ -3109,7 +3363,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemotePeerConfig); i { + switch v := v.(*PeerConfig); i { case 0: return &v.state case 1: @@ -3121,7 +3375,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHConfig); i { + switch v := v.(*NetworkMap); i { case 0: return &v.state case 1: @@ -3133,7 +3387,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlowRequest); i { + switch v := v.(*RemotePeerConfig); i { case 0: return &v.state case 1: @@ -3145,7 +3399,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlow); i { + switch v := v.(*SSHConfig); i { case 0: return &v.state case 1: @@ -3157,7 +3411,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlowRequest); i { + switch v := v.(*DeviceAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -3169,7 +3423,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlow); i { + switch v := v.(*DeviceAuthorizationFlow); i { case 0: return &v.state case 1: @@ -3181,7 +3435,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProviderConfig); i { + switch v := v.(*PKCEAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -3193,7 +3447,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*PKCEAuthorizationFlow); i { case 0: return &v.state case 1: @@ -3205,7 +3459,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DNSConfig); i { + switch v := v.(*ProviderConfig); i { case 0: return &v.state case 1: @@ -3217,7 +3471,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CustomZone); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -3229,7 +3483,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SimpleRecord); i { + switch v := v.(*DNSConfig); i { case 0: return &v.state case 1: @@ -3241,7 +3495,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServerGroup); i { + switch v := v.(*CustomZone); i { case 0: return &v.state case 1: @@ -3253,7 +3507,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServer); i { + switch v := v.(*SimpleRecord); i { case 0: return &v.state case 1: @@ -3265,7 +3519,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*FirewallRule); i { + switch v := v.(*NameServerGroup); i { case 0: return &v.state case 1: @@ -3277,6 +3531,30 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*NameServer); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*FirewallRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*NetworkAddress); i { case 0: return &v.state @@ -3288,6 +3566,18 @@ func file_management_proto_init() { return nil } } + file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Checks); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -3295,7 +3585,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 31, + NumMessages: 34, NumExtensions: 0, NumServices: 1, }, diff --git a/management/proto/management.proto b/management/proto/management.proto index c6695dd6a..c5646820f 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -38,6 +38,12 @@ service ManagementService { // EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest. // EncryptedMessage of the response has a body of PKCEAuthorizationFlow. rpc GetPKCEAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {} + + // SyncMeta is used to sync metadata of the peer. + // After sync the peer if there is a change in peer posture check which needs to be evaluated by the client, + // sync meta will evaluate the checks and update the peer meta with the result. + // EncryptedMessage of the request has a body of Empty. + rpc SyncMeta(EncryptedMessage) returns (Empty) {} } message EncryptedMessage { @@ -50,7 +56,10 @@ message EncryptedMessage { int32 version = 3; } -message SyncRequest {} +message SyncRequest { + // Meta data of the peer + PeerSystemMeta meta = 1; +} // SyncResponse represents a state that should be applied to the local peer (e.g. Wiretrustee servers config as well as local peer and remote peers configs) message SyncResponse { @@ -69,6 +78,14 @@ message SyncResponse { bool remotePeersIsEmpty = 4; NetworkMap NetworkMap = 5; + + // Posture checks to be evaluated by client + repeated Checks Checks = 6; +} + +message SyncMetaRequest { + // Meta data of the peer + PeerSystemMeta meta = 1; } message LoginRequest { @@ -82,6 +99,7 @@ message LoginRequest { PeerKeys peerKeys = 4; } + // PeerKeys is additional peer info like SSH pub key and WireGuard public key. // This message is sent on Login or register requests, or when a key rotation has to happen. message PeerKeys { @@ -100,6 +118,16 @@ message Environment { string platform = 2; } +// File represents a file on the system. +message File { + // path is the path to the file. + string path = 1; + // exist indicate whether the file exists. + bool exist = 2; + // processIsRunning indicates whether the file is a running process or not. + bool processIsRunning = 3; +} + // PeerSystemMeta is machine meta data like OS and version. message PeerSystemMeta { string hostname = 1; @@ -117,6 +145,7 @@ message PeerSystemMeta { string sysProductName = 13; string sysManufacturer = 14; Environment environment = 15; + repeated File files = 16; } message LoginResponse { @@ -124,6 +153,8 @@ message LoginResponse { WiretrusteeConfig wiretrusteeConfig = 1; // Peer local config PeerConfig peerConfig = 2; + // Posture checks to be evaluated by client + repeated Checks Checks = 3; } message ServerKeyResponse { @@ -312,6 +343,8 @@ message Route { int64 Metric = 5; bool Masquerade = 6; string NetID = 7; + repeated string Domains = 8; + bool keepRoute = 9; } // DNSConfig represents a dns.Update @@ -380,3 +413,7 @@ message NetworkAddress { string netIP = 1; string mac = 2; } + +message Checks { + repeated string Files= 1; +} diff --git a/management/proto/management_grpc.pb.go b/management/proto/management_grpc.pb.go index 5e2bcd225..badf242f5 100644 --- a/management/proto/management_grpc.pb.go +++ b/management/proto/management_grpc.pb.go @@ -43,6 +43,11 @@ type ManagementServiceClient interface { // EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest. // EncryptedMessage of the response has a body of PKCEAuthorizationFlow. GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) + // SyncMeta is used to sync metadata of the peer. + // After sync the peer if there is a change in peer posture check which needs to be evaluated by the client, + // sync meta will evaluate the checks and update the peer meta with the result. + // EncryptedMessage of the request has a body of Empty. + SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) } type managementServiceClient struct { @@ -130,6 +135,15 @@ func (c *managementServiceClient) GetPKCEAuthorizationFlow(ctx context.Context, return out, nil } +func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := c.cc.Invoke(ctx, "/management.ManagementService/SyncMeta", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // ManagementServiceServer is the server API for ManagementService service. // All implementations must embed UnimplementedManagementServiceServer // for forward compatibility @@ -159,6 +173,11 @@ type ManagementServiceServer interface { // EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest. // EncryptedMessage of the response has a body of PKCEAuthorizationFlow. GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) + // SyncMeta is used to sync metadata of the peer. + // After sync the peer if there is a change in peer posture check which needs to be evaluated by the client, + // sync meta will evaluate the checks and update the peer meta with the result. + // EncryptedMessage of the request has a body of Empty. + SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) mustEmbedUnimplementedManagementServiceServer() } @@ -184,6 +203,9 @@ func (UnimplementedManagementServiceServer) GetDeviceAuthorizationFlow(context.C func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) { return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented") } +func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented") +} func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {} // UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service. @@ -308,6 +330,24 @@ func _ManagementService_GetPKCEAuthorizationFlow_Handler(srv interface{}, ctx co return interceptor(ctx, in, info, handler) } +func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EncryptedMessage) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ManagementServiceServer).SyncMeta(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ManagementService/SyncMeta", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ManagementServiceServer).SyncMeta(ctx, req.(*EncryptedMessage)) + } + return interceptor(ctx, in, info, handler) +} + // ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -335,6 +375,10 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetPKCEAuthorizationFlow", Handler: _ManagementService_GetPKCEAuthorizationFlow_Handler, }, + { + MethodName: "SyncMeta", + Handler: _ManagementService_SyncMeta_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/management/server/account.go b/management/server/account.go index 984139a12..27c21e402 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -11,6 +11,7 @@ import ( "net/netip" "reflect" "regexp" + "slices" "strings" "sync" "time" @@ -20,9 +21,11 @@ import ( gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" @@ -56,83 +59,85 @@ func cacheEntryExpiration() time.Duration { } type AccountManager interface { - GetOrCreateAccountByUser(userId, domain string) (*Account, error) - CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, + GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) + CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) - SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error) - CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) - DeleteUser(accountID, initiatorUserID string, targetUserID string) error - InviteUser(accountID string, initiatorUserID string, targetUserID string) error - ListSetupKeys(accountID, userID string) ([]*SetupKey, error) - SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) - SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) - GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) - GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) - GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) - CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error - GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) - DeleteAccount(accountID, userID string) error - MarkPATUsed(tokenID string) error - GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) - ListUsers(accountID string) ([]*User, error) - GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error - DeletePeer(accountID, peerID, userID string) error - UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - GetNetworkMap(peerID string) (*NetworkMap, error) - GetPeerNetwork(peerID string) (*Network, error) - AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, error) - CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) - DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error - GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) - GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - UpdatePeerSSHKey(peerID string, sshKey string) error - GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) - GetGroup(accountId, groupID, userID string) (*nbgroup.Group, error) - GetAllGroups(accountID, userID string) ([]*nbgroup.Group, error) - GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) - SaveGroup(accountID, userID string, group *nbgroup.Group) error - DeleteGroup(accountId, userId, groupID string) error - ListGroups(accountId string) ([]*nbgroup.Group, error) - GroupAddPeer(accountId, groupID, peerID string) error - GroupDeletePeer(accountId, groupID, peerID string) error - GetPolicy(accountID, policyID, userID string) (*Policy, error) - SavePolicy(accountID, userID string, policy *Policy) error - DeletePolicy(accountID, policyID, userID string) error - ListPolicies(accountID, userID string) ([]*Policy, error) - GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) - CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) - SaveRoute(accountID, userID string, route *route.Route) error - DeleteRoute(accountID string, routeID route.ID, userID string) error - ListRoutes(accountID, userID string) ([]*route.Route, error) - GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - DeleteNameServerGroup(accountID, nsGroupID, userID string) error - ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) + SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) + CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) + DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error + InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) + SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) + SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) + GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) + GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) + GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) + CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error + GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) + DeleteAccount(ctx context.Context, accountID, userID string) error + MarkPATUsed(ctx context.Context, tokenID string) error + GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) + ListUsers(ctx context.Context, accountID string) ([]*User, error) + GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *Account) error + DeletePeer(ctx context.Context, accountID, peerID, userID string) error + UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) + GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) + AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error + GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) + GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) + UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error + GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) + GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) + SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error + DeleteGroup(ctx context.Context, accountId, userId, groupID string) error + ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) + GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error + GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error + GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error + DeletePolicy(ctx context.Context, accountID, policyID, userID string) error + ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) + GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) + CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error + DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error + ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) + GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error + DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error + ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) GetDNSDomain() string - StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) - GetEvents(accountID, userID string) ([]*activity.Event, error) - GetDNSSettings(accountID string, userID string) (*DNSSettings, error) - SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error - GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) - LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API - SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API + StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) + GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) + GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) + SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error + GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) + LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager - GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error - DeletePostureChecks(accountID, postureChecksID, userID string) error - ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) + GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error + ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager - UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error - GroupValidation(accountId string, groups []string) (bool, error) + UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error + GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(account *Account) (map[string]struct{}, error) - SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) - CancelPeerRoutines(peer *nbpeer.Peer) error + SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error + SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) } type DefaultAccountManager struct { @@ -270,16 +275,16 @@ type UserInfo struct { // getRoutesToSync returns the enabled routes for the peer ID and the routes // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*route.Route { - routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID) +func (a *Account) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { + routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) peerRoutesMembership := make(lookupMap) for _, r := range append(routes, peerDisabledRoutes...) { - peerRoutesMembership[string(route.GetHAUniqueID(r))] = struct{}{} + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} } groupListMap := a.getPeerGroups(peerID) for _, peer := range aclPeers { - activeRoutes, _ := a.getRoutingPeerRoutes(peer.ID) + activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) routes = append(routes, filteredRoutes...) @@ -292,7 +297,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { var filteredRoutes []*route.Route for _, r := range routes { - _, found := peerMemberships[string(route.GetHAUniqueID(r))] + _, found := peerMemberships[string(r.GetHAUniqueID())] if !found { filteredRoutes = append(filteredRoutes, r) } @@ -318,11 +323,11 @@ func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap looku // getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. // If the given is not a routing peer, then the lists are empty. -func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { +func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { peer := a.GetPeer(peerID) if peer == nil { - log.Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) + log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) return enabledRoutes, disabledRoutes } @@ -351,7 +356,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro for _, groupID := range r.PeerGroups { group := a.GetGroup(groupID) if group == nil { - log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) + log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) continue } for _, id := range group.Peers { @@ -375,11 +380,13 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro return enabledRoutes, disabledRoutes } -// GetRoutesByPrefix return list of routes by account and route prefix -func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route { +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { var routes []*route.Route for _, r := range a.Routes { - if r.Network.String() == prefix.String() { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { routes = append(routes, r) } } @@ -393,7 +400,7 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group { } // GetPeerNetworkMap returns a group by ID if exists, nil otherwise -func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { +func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ @@ -407,7 +414,7 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap } } - aclPeers, firewallRules := a.getPeerConnectionResources(peerID, validatedPeersMap) + aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -420,7 +427,7 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap peersToConnect = append(peersToConnect, p) } - routesUpdate := a.getRoutesToSync(peerID, peersToConnect) + routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) dnsUpdate := nbdns.Config{ @@ -429,7 +436,7 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap if dnsManagementStatus { var zones []nbdns.CustomZone - peersCustomZone := getPeersCustomZone(a, dnsDomain) + peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain) if peersCustomZone.Domain != "" { zones = append(zones, peersCustomZone) } @@ -758,8 +765,13 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer { return a.Peers[peerID] } -// SetJWTGroups to account and to user autoassigned groups +// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. +// Returns true if there are changes in the JWT group membership. func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { + if len(groupsNames) == 0 { + return false + } + user, ok := a.Users[userID] if !ok { return false @@ -770,23 +782,19 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { existedGroupsByName[group.Name] = group } - // remove JWT groups from the autogroups, to sync them again - removed := 0 - jwtAutoGroups := make(map[string]struct{}) - for i, id := range user.AutoGroups { - if group, ok := a.Groups[id]; ok && group.Issued == nbgroup.GroupIssuedJWT { - jwtAutoGroups[group.Name] = struct{}{} - user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...) - removed++ - } + newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) + groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) + + // If no groups are added or removed, we should not sync account + if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { + return false } - // create JWT groups if they doesn't exist - // and all of them to the autogroups var modified bool - for _, name := range groupsNames { - group, ok := existedGroupsByName[name] - if !ok { + for _, name := range groupsToAdd { + group, exists := existedGroupsByName[name] + if !exists { group = &nbgroup.Group{ ID: xid.New().String(), Name: name, @@ -794,20 +802,20 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { } a.Groups[group.ID] = group } - // only JWT groups will be synced if group.Issued == nbgroup.GroupIssuedJWT { - user.AutoGroups = append(user.AutoGroups, group.ID) - if _, ok := jwtAutoGroups[name]; !ok { - modified = true - } - delete(jwtAutoGroups, name) + newAutoGroups = append(newAutoGroups, group.ID) + modified = true } } - // if not empty it means we removed some groups - if len(jwtAutoGroups) > 0 { + for name, id := range jwtGroupsMap { + if !slices.Contains(groupsToRemove, name) { + newAutoGroups = append(newAutoGroups, id) + continue + } modified = true } + user.AutoGroups = newAutoGroups return modified } @@ -865,7 +873,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { } // BuildManager creates a new DefaultAccountManager with a provided Store -func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, +func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, userDeleteFromIDPEnabled bool, integratedPeerValidator integrated_validator.IntegratedValidator, @@ -884,7 +892,7 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, } - allAccounts := store.GetAllAccounts() + allAccounts := store.GetAllAccounts(ctx) // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 if am.singleAccountMode { @@ -892,9 +900,9 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain - log.Infof("single account mode enabled, accounts number %d", len(allAccounts)) + log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts)) } else { - log.Infof("single account mode disabled, accounts number %d", len(allAccounts)) + log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts)) } // if account doesn't have a default group @@ -912,7 +920,7 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage } if shouldSave { - err = store.SaveAccount(account) + err = store.SaveAccount(ctx, account) if err != nil { return nil, err } @@ -930,16 +938,18 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage if !isNil(am.idpManager) { go func() { - err := am.warmupIDPCache() + err := am.warmupIDPCache(ctx) if err != nil { - log.Warnf("failed warming up cache due to error: %v", err) + log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? return } }() } - am.integratedPeerValidator.SetPeerInvalidationListener(am.onPeersInvalidated) + am.integratedPeerValidator.SetPeerInvalidationListener(func(accountID string) { + am.onPeersInvalidated(ctx, accountID) + }) return am, nil } @@ -956,7 +966,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) { +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -966,10 +976,10 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -983,7 +993,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") } - err = am.integratedPeerValidator.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) + err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) if err != nil { return nil, err } @@ -993,21 +1003,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { event = activity.AccountPeerLoginExpirationDisabled - am.peerLoginExpiry.Cancel([]string{accountID}) + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(account) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } - am.StoreEvent(userID, accountID, accountID, event, nil) + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { - am.StoreEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(account) + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } updatedAccount := account.UpdateSettings(newSettings) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } @@ -1015,14 +1025,14 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return updatedAccount, nil } -func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) { +func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { - log.Errorf("failed getting account %s expiring peers", account.Id) + log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID) return account.GetNextPeerExpiration() } @@ -1032,10 +1042,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() peerIDs = append(peerIDs, peer.ID) } - log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) - if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil { - log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id) return account.GetNextPeerExpiration() } @@ -1043,28 +1053,28 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(account *Account) { - am.peerLoginExpiry.Cancel([]string{account.Id}) +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { + am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) if nextRun, ok := account.GetNextPeerExpiration(); ok { - go am.peerLoginExpiry.Schedule(nextRun, account.Id, am.peerLoginExpirationJob(account.Id)) + go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) } } // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error -func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { for i := 0; i < 2; i++ { accountId := xid.New().String() - _, err := am.Store.GetAccount(accountId) + _, err := am.Store.GetAccount(ctx, accountId) statusErr, _ := status.FromError(err) switch { case err == nil: - log.Warnf("an account with ID already exists, retrying...") + log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") continue case statusErr.Type() == status.NotFound: - newAccount := newAccountWithId(accountId, userID, domain) - am.StoreEvent(userID, newAccount.Id, accountId, activity.AccountCreated, nil) + newAccount := newAccountWithId(ctx, accountId, userID, domain) + am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil) return newAccount, nil default: return nil, err @@ -1074,12 +1084,12 @@ func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, er return nil, status.Errorf(status.Internal, "error while creating new account") } -func (am *DefaultAccountManager) warmupIDPCache() error { - userData, err := am.idpManager.GetAllAccounts() +func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { + userData, err := am.idpManager.GetAllAccounts(ctx) if err != nil { return err } - log.Infof("%d entries received from IdP management", len(userData)) + log.WithContext(ctx).Infof("%d entries received from IdP management", len(userData)) // If the Identity Provider does not support writing AppMetadata, // in cases like this, we expect it to return all users in an "unset" field. @@ -1087,7 +1097,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error { // update their AppMetadata with the AccountID. if unsetData, ok := userData[idp.UnsetAccountID]; ok { for _, user := range unsetData { - accountID, err := am.Store.GetAccountByUser(user.ID) + accountID, err := am.Store.GetAccountByUser(ctx, user.ID) if err == nil { data := userData[accountID.Id] if data == nil { @@ -1110,15 +1120,15 @@ func (am *DefaultAccountManager) warmupIDPCache() error { return err } } - log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData)) + log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData)) return nil } // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner -func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -1144,42 +1154,42 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error { continue } - deleteUserErr := am.deleteRegularUser(account, userID, otherUser.Id) + deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id) if deleteUserErr != nil { return deleteUserErr } } - err = am.deleteRegularUser(account, userID, userID) + err = am.deleteRegularUser(ctx, account, userID, userID) if err != nil { - log.Errorf("failed deleting user %s. error: %s", userID, err) + log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err) return err } - err = am.Store.DeleteAccount(account) + err = am.Store.DeleteAccount(ctx, account) if err != nil { - log.Errorf("failed deleting account %s. error: %s", accountID, err) + log.WithContext(ctx).Errorf("failed deleting account %s. error: %s", accountID, err) return err } // cancel peer login expiry job - am.peerLoginExpiry.Cancel([]string{account.Id}) + am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) - log.Debugf("account %s deleted", accountID) + log.WithContext(ctx).Debugf("account %s deleted", accountID) return nil } // GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and // userID doesn't have an account associated with it, one account is created // domain is used to create a new account if no account is found -func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) { if accountID != "" { - return am.Store.GetAccount(accountID) + return am.Store.GetAccount(ctx, accountID) } else if userID != "" { - account, err := am.GetOrCreateAccountByUser(userID, domain) + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) if err != nil { return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID) } - err = am.addAccountIDToIDPAppMeta(userID, account) + err = am.addAccountIDToIDPAppMeta(ctx, userID, account) if err != nil { return nil, err } @@ -1194,28 +1204,28 @@ func isNil(i idp.Manager) bool { } // addAccountIDToIDPAppMeta update user's app metadata in idp manager -func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account *Account) error { +func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error { if !isNil(am.idpManager) { // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(userID, account) + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err } if user != nil && user.AppMetadata.WTAccountID == account.Id { // it was already set, so we skip the unnecessary update - log.Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", + log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", account.Id, userID) return nil } - err = am.idpManager.UpdateUserAppMetadata(userID, idp.AppMetadata{WTAccountID: account.Id}) + err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id}) if err != nil { return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) } // refresh cache to reflect the update - _, err = am.refreshCache(account.Id) + _, err = am.refreshCache(ctx, account.Id) if err != nil { return err } @@ -1223,20 +1233,20 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account return nil } -func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) { - log.Debugf("account %s not found in cache, reloading", accountID) +func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID interface{}) ([]*idp.UserData, error) { + log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) accountIDString := fmt.Sprintf("%v", accountID) - account, err := am.Store.GetAccount(accountIDString) + account, err := am.Store.GetAccount(ctx, accountIDString) if err != nil { return nil, err } - userData, err := am.idpManager.GetAccount(accountIDString) + userData, err := am.idpManager.GetAccount(ctx, accountIDString) if err != nil { return nil, err } - log.Debugf("%d entries received from IdP management", len(userData)) + log.WithContext(ctx).Debugf("%d entries received from IdP management", len(userData)) dataMap := make(map[string]*idp.UserData, len(userData)) for _, datum := range userData { @@ -1250,7 +1260,7 @@ func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interf } datum, ok := dataMap[user.Id] if !ok { - log.Warnf("user %s not found in IDP", user.Id) + log.WithContext(ctx).Warnf("user %s not found in IDP", user.Id) continue } matchedUserData = append(matchedUserData, datum) @@ -1258,8 +1268,8 @@ func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interf return matchedUserData, nil } -func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) { - data, err := am.getAccountFromCache(accountID, false) +func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, email string, accountID string) (*idp.UserData, error) { + data, err := am.getAccountFromCache(ctx, accountID, false) if err != nil { return nil, err } @@ -1274,7 +1284,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil -func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) { +func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) { users := make(map[string]userLoggedInOnce, len(account.Users)) // ignore service users and users provisioned by integrations than are never logged in for _, user := range account.Users { @@ -1286,8 +1296,8 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou } users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) } - log.Debugf("looking up user %s of account %s in cache", userID, account.Id) - userData, err := am.lookupCache(users, account.Id) + log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id) + userData, err := am.lookupCache(ctx, users, account.Id) if err != nil { return nil, err } @@ -1302,25 +1312,25 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta user, err := account.FindUser(userID) if err != nil { - log.Errorf("failed finding user %s in account %s", userID, account.Id) + log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id) return nil, err } key := user.IntegrationReference.CacheKey(account.Id, userID) ud, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { - log.Debugf("failed to get externalCache for key: %s, error: %s", key, err) + log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err) } return ud, nil } -func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) { - return am.getAccountFromCache(accountID, true) +func (am *DefaultAccountManager) refreshCache(ctx context.Context, accountID string) ([]*idp.UserData, error) { + return am.getAccountFromCache(ctx, accountID, true) } // getAccountFromCache returns user data for a given account ensuring that cache load happens only once -func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceReload bool) ([]*idp.UserData, error) { +func (am *DefaultAccountManager) getAccountFromCache(ctx context.Context, accountID string, forceReload bool) ([]*idp.UserData, error) { am.cacheMux.Lock() loadingChan := am.cacheLoading[accountID] if loadingChan == nil { @@ -1346,7 +1356,7 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo } am.cacheMux.Unlock() - log.Debugf("one request to get account %s is already running", accountID) + log.WithContext(ctx).Debugf("one request to get account %s is already running", accountID) select { case <-loadingChan: @@ -1357,19 +1367,19 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo } } -func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) { +func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) { var data []*idp.UserData var err error maxAttempts := 2 - data, err = am.getAccountFromCache(accountID, false) + data, err = am.getAccountFromCache(ctx, accountID, false) if err != nil { return nil, err } for attempt := 1; attempt <= maxAttempts; attempt++ { - if am.isCacheFresh(accountUsers, data) { + if am.isCacheFresh(ctx, accountUsers, data) { return data, nil } @@ -1377,14 +1387,14 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedI time.Sleep(200 * time.Millisecond) } - log.Infof("refreshing cache for account %s", accountID) - data, err = am.refreshCache(accountID) + log.WithContext(ctx).Infof("refreshing cache for account %s", accountID) + data, err = am.refreshCache(ctx, accountID) if err != nil { return nil, err } if attempt == maxAttempts { - log.Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts) + log.WithContext(ctx).Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts) } } @@ -1392,7 +1402,7 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedI } // isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status -func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool { +func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool { userDataMap := make(map[string]*idp.UserData, len(data)) for _, datum := range data { userDataMap[datum.ID] = datum @@ -1405,26 +1415,26 @@ func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLogged if datum, ok := userDataMap[user]; ok { // check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple - log.Infof("user %s has a pending invite and has logged in once, cache invalid", user) + log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user) return false } knownUsersCount-- continue } - log.Debugf("cache doesn't know about %s user", user) + log.WithContext(ctx).Debugf("cache doesn't know about %s user", user) } // if we know users that are not yet in cache more likely cache is outdated if knownUsersCount > 0 { - log.Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount) + log.WithContext(ctx).Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount) return false } return true } -func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) error { - data, err := am.getAccountFromCache(accountID, false) +func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accountID, userID string) error { + data, err := am.getAccountFromCache(ctx, accountID, false) if err != nil { return err } @@ -1440,7 +1450,7 @@ func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) e } // updateAccountDomainAttributes updates the account domain attributes and then, saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, +func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { @@ -1457,10 +1467,10 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, account.DomainCategory = claims.DomainCategory } } else { - log.Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) } - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(ctx, account) if err != nil { return err } @@ -1469,17 +1479,18 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, // handleExistingUserAccount handles existing User accounts and update its domain attributes. func (am *DefaultAccountManager) handleExistingUserAccount( + ctx context.Context, existingAcc *Account, primaryDomain bool, claims jwtclaims.AuthorizationClaims, ) error { - err := am.updateAccountDomainAttributes(existingAcc, claims, primaryDomain) + err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(claims.UserId, existingAcc) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc) if err != nil { return err } @@ -1489,7 +1500,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { if claims.UserId == "" { return nil, fmt.Errorf("user ID is empty") } @@ -1502,40 +1513,40 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims if domainAcc != nil { account = domainAcc account.Users[claims.UserId] = NewRegularUser(claims.UserId) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } } else { - account, err = am.newAccount(claims.UserId, lowerDomain) + account, err = am.newAccount(ctx, claims.UserId, lowerDomain) if err != nil { return nil, err } - err = am.updateAccountDomainAttributes(account, claims, true) + err = am.updateAccountDomainAttributes(ctx, account, claims, true) if err != nil { return nil, err } } - err = am.addAccountIDToIDPAppMeta(claims.UserId, account) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account) if err != nil { return nil, err } - am.StoreEvent(claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) + am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) return account, nil } // redeemInvite checks whether user has been invited and redeems the invite -func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) error { +func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { // only possible with the enabled IdP manager if am.idpManager == nil { - log.Warnf("invites only work with enabled IdP manager") + log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") return nil } - user, err := am.lookupUserInCache(userID, account) + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err } @@ -1545,17 +1556,17 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e } if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite { - log.Infof("redeeming invite for user %s account %s", userID, account.Id) + log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id) // User has already logged in, meaning that IdP should have set wt_pending_invite to false. // Our job is to just reload cache. go func() { - _, err = am.refreshCache(account.Id) + _, err = am.refreshCache(ctx, account.Id) if err != nil { - log.Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id) + log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id) return } - log.Debugf("user %s of account %s redeemed invite", user.ID, account.Id) - am.StoreEvent(userID, userID, account.Id, activity.UserJoined, nil) + log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id) + am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil) }() } @@ -1563,22 +1574,22 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e } // MarkPATUsed marks a personal access token as used -func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { +func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { - user, err := am.Store.GetUserByTokenID(tokenID) + user, err := am.Store.GetUserByTokenID(ctx, tokenID) if err != nil { return err } - account, err := am.Store.GetAccountByUser(user.Id) + account, err := am.Store.GetAccountByUser(ctx, user.Id) if err != nil { return err } - unlock := am.Store.AcquireAccountWriteLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) defer unlock() - account, err = am.Store.GetAccountByUser(user.Id) + account, err = am.Store.GetAccountByUser(ctx, user.Id) if err != nil { return err } @@ -1590,11 +1601,11 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { pat.LastUsed = time.Now().UTC() - return am.Store.SaveAccount(account) + return am.Store.SaveAccount(ctx, account) } // GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) { +func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { return nil, nil, nil, fmt.Errorf("token has wrong length") } @@ -1618,17 +1629,17 @@ func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *Use hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - tokenID, err := am.Store.GetTokenIDByHashedToken(encodedHashedToken) + tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken) if err != nil { return nil, nil, nil, err } - user, err := am.Store.GetUserByTokenID(tokenID) + user, err := am.Store.GetUserByTokenID(ctx, tokenID) if err != nil { return nil, nil, nil, err } - account, err := am.Store.GetAccountByUser(user.Id) + account, err := am.Store.GetAccountByUser(ctx, user.Id) if err != nil { return nil, nil, nil, err } @@ -1642,7 +1653,7 @@ func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *Use } // GetAccountFromToken returns an account associated with this token -func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { +func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { if claims.UserId == "" { return nil, nil, fmt.Errorf("user ID is empty") } @@ -1651,14 +1662,14 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat // We override incoming domain claims to group users under a single account. claims.Domain = am.singleAccountModeDomain claims.DomainCategory = PrivateCategory - log.Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") + log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - newAcc, err := am.getAccountWithAuthorizationClaims(claims) + newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims) if err != nil { return nil, nil, err } - unlock := am.Store.AcquireAccountWriteLock(newAcc.Id) + unlock := am.Store.AcquireAccountWriteLock(ctx, newAcc.Id) alreadyUnlocked := false defer func() { if !alreadyUnlocked { @@ -1666,7 +1677,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } }() - account, err := am.Store.GetAccount(newAcc.Id) + account, err := am.Store.GetAccount(ctx, newAcc.Id) if err != nil { return nil, nil, err } @@ -1678,7 +1689,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } if !user.IsServiceUser && claims.Invited { - err = am.redeemInvite(account, claims.UserId) + err = am.redeemInvite(ctx, account, claims.UserId) if err != nil { return nil, nil, err } @@ -1686,7 +1697,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat if account.Settings.JWTGroupsEnabled { if account.Settings.JWTGroupsClaimName == "" { - log.Errorf("JWT groups are enabled but no claim name is set") + log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") return account, user, nil } if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { @@ -1696,7 +1707,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat if g, ok := item.(string); ok { groupsNames = append(groupsNames, g) } else { - log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) + log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) } } @@ -1711,15 +1722,16 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) account.Network.IncSerial() - if err := am.Store.SaveAccount(account); err != nil { - log.Errorf("failed to save account: %v", err) + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to save account: %v", err) } else { - am.updateAccountPeers(account) + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) unlock() alreadyUnlocked = true for _, g := range addNewGroups { if group := account.GetGroup(g); group != nil { - am.StoreEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser, + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, map[string]any{ "group": group.Name, "group_id": group.ID, @@ -1729,7 +1741,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } for _, g := range removeOldGroups { if group := account.GetGroup(g); group != nil { - am.StoreEvent(user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, map[string]any{ "group": group.Name, "group_id": group.ID, @@ -1740,16 +1752,16 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } } } else { - if err := am.Store.SaveAccount(account); err != nil { - log.Errorf("failed to save account: %v", err) + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to save account: %v", err) } } } } else { - log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName) + log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName) } } else { - log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName) + log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName) } } @@ -1773,8 +1785,8 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) { - log.Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", +func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { + log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) if claims.UserId == "" { return nil, fmt.Errorf("user ID is empty") @@ -1782,9 +1794,9 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain) + return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) } else if claims.AccountId != "" { - accountFromID, err := am.Store.GetAccount(claims.AccountId) + accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) if err != nil { return nil, err } @@ -1797,12 +1809,12 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } start := time.Now() - unlock := am.Store.AcquireGlobalLock() + unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() - log.Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) + log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) // We checked if the domain has a primary account already - domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain) + domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { // if NotFound we are good to continue, otherwise return error e, ok := status.FromError(err) @@ -1811,11 +1823,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } } - account, err := am.Store.GetAccountByUser(claims.UserId) + account, err := am.Store.GetAccountByUser(ctx, claims.UserId) if err == nil { - unlockAccount := am.Store.AcquireAccountWriteLock(account.Id) + unlockAccount := am.Store.AcquireAccountWriteLock(ctx, account.Id) defer unlockAccount() - account, err = am.Store.GetAccountByUser(claims.UserId) + account, err = am.Store.GetAccountByUser(ctx, claims.UserId) if err != nil { return nil, err } @@ -1826,59 +1838,59 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla // and peers that shouldn't be lost. primaryDomain := domainAccount == nil || account.Id == domainAccount.Id - err = am.handleExistingUserAccount(account, primaryDomain, claims) + err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) if err != nil { return nil, err } return account, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if domainAccount != nil { - unlockAccount := am.Store.AcquireAccountWriteLock(domainAccount.Id) + unlockAccount := am.Store.AcquireAccountWriteLock(ctx, domainAccount.Id) defer unlockAccount() - domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain) + domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { return nil, err } } - return am.handleNewUserAccount(domainAccount, claims) + return am.handleNewUserAccount(ctx, domainAccount, claims) } else { // other error return nil, err } } -func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) { - accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey) +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - return nil, nil, status.Errorf(status.Unauthenticated, "peer not registered") + return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered") } - return nil, nil, err + return nil, nil, nil, err } - unlock := am.Store.AcquireAccountReadLock(accountID) + unlock := am.Store.AcquireAccountReadLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account) + peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - err = am.MarkPeerConnected(peerPubKey, true, realIP, account) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) if err != nil { - log.Warnf("failed marking peer as connected %s %v", peerPubKey, err) + log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } - return peer, netMap, nil + return peer, netMap, postureChecks, nil } -func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { - accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key) +func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error { + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { return status.Errorf(status.Unauthenticated, "peer not registered") @@ -1886,23 +1898,44 @@ func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { return err } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - err = am.MarkPeerConnected(peer.Key, false, nil, account) + err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account) if err != nil { - log.Warnf("failed marking peer as connected %s %v", peer.Key, err) + log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err) } return nil } +func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error { + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) + if err != nil { + return err + } + + unlock := am.Store.AcquireAccountReadLock(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + + _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account) + if err != nil { + return mapError(ctx, err) + } + return nil +} + // GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { return am.peersUpdateManager.GetAllConnectedPeers(), nil @@ -1926,8 +1959,8 @@ func (am *DefaultAccountManager) GetDNSDomain() string { // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // group propagation and set the list of groups with access permissions. -func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { - account, _, err := am.GetAccountFromToken(claims) +func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { + account, _, err := am.GetAccountFromToken(ctx, claims) if err != nil { return err } @@ -1957,20 +1990,24 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut return nil } -func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { - log.Debugf("validated peers has been invalidated for account %s", accountID) - updatedAccount, err := am.Store.GetAccount(accountID) +func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { + log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) + updatedAccount, err := am.Store.GetAccount(ctx, accountID) if err != nil { - log.Errorf("failed to get account %s: %v", accountID, err) + log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) return } - am.updateAccountPeers(updatedAccount) + am.updateAccountPeers(ctx, updatedAccount) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { return am.Store.GetPostureCheckByChecksDefinition(accountID, checks) } +func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) { + return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { @@ -2012,8 +2049,8 @@ func addAllGroup(account *Account) error { } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(accountID, userID, domain string) *Account { - log.Debugf("creating new account") +func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account { + log.WithContext(ctx).Debugf("creating new account") network := NewNetwork() peers := make(map[string]*nbpeer.Peer) @@ -2025,7 +2062,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } - log.Debugf("created new account %s", accountID) + log.WithContext(ctx).Debugf("created new account %s", accountID) acc := &Account{ Id: accountID, @@ -2048,7 +2085,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { } if err := addAllGroup(acc); err != nil { - log.Errorf("error adding all group to account %s: %v", acc.Id, err) + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) } return acc } @@ -2064,3 +2101,22 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { } return false } + +// separateGroups separates user's auto groups into non-JWT and JWT groups. +// Returns the list of standard auto groups and a map of JWT auto groups, +// where the keys are the group names and the values are the group IDs. +func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { + newAutoGroups := make([]string, 0) + jwtAutoGroups := make(map[string]string) // map of group name to group ID + + for _, id := range autoGroups { + if group, ok := allGroups[id]; ok { + if group.Issued == nbgroup.GroupIssuedJWT { + jwtAutoGroups[group.Name] = id + } else { + newAutoGroups = append(newAutoGroups, id) + } + } + } + return newAutoGroups, jwtAutoGroups +} diff --git a/management/server/account_test.go b/management/server/account_test.go index 38c9fabbc..71b43bd65 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/sha256" b64 "encoding/base64" "encoding/json" @@ -29,11 +30,11 @@ import ( type MocIntegratedValidator struct { } -func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { +func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { return nil } -func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { return update, nil } func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { @@ -44,15 +45,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s return validatedPeers, nil } -func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { +func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { return peer } -func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { +func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { return false, false, nil } -func (MocIntegratedValidator) PeerDeleted(_, _ string) error { +func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { return nil } @@ -60,7 +61,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string) } -func (MocIntegratedValidator) Stop() { +func (MocIntegratedValidator) Stop(_ context.Context) { } func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { @@ -85,7 +86,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac setupKey = key.Key } - _, _, err := manager.AddPeer(setupKey, userID, peer) + _, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer) if err != nil { t.Error("expected to add new peer successfully after creating new account, but failed", err) } @@ -395,7 +396,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } for _, testCase := range tt { - account := newAccountWithId("account-1", userID, "netbird.io") + account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io") account.UpdateSettings(&testCase.accountSettings) account.Network = network account.Peers = testCase.peers @@ -409,7 +410,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { validatedPeers[p] = struct{}{} } - networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -419,7 +420,7 @@ func TestNewAccount(t *testing.T) { domain := "netbird.io" userId := "account_creator" accountID := "account_id" - account := newAccountWithId(accountID, userId, domain) + account := newAccountWithId(context.Background(), accountID, userId, domain) verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) } @@ -430,7 +431,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } @@ -439,7 +440,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { return } - account, err = manager.Store.GetAccountByUser(userID) + account, err = manager.Store.GetAccountByUser(context.Background(), userID) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID) return @@ -630,11 +631,11 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -642,7 +643,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - account, _, err := manager.GetAccountFromToken(testCase.inputClaims) + account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) require.NoError(t, err, "support function failed") verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) @@ -661,12 +662,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - initAccount := newAccountWithId("", userId, domain) + initAccount := newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID := initAccount.Id - acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain) + acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization // that happens inside the GetAccountByUserOrAccountID where the id is getting generated @@ -682,18 +683,18 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { } t.Run("JWT groups disabled", func(t *testing.T) { - account, _, err := manager.GetAccountFromToken(claims) + account, _, err := manager.GetAccountFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 1, "only ALL group should exists") }) t.Run("JWT groups enabled without claim name", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true - err := manager.Store.SaveAccount(initAccount) + err := manager.Store.SaveAccount(context.Background(), initAccount) require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist") + require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(claims) + account, _, err := manager.GetAccountFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") }) @@ -701,11 +702,11 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { t.Run("JWT groups enabled", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsClaimName = "idp-groups" - err := manager.Store.SaveAccount(initAccount) + err := manager.Store.SaveAccount(context.Background(), initAccount) require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist") + require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(claims) + account, _, err := manager.GetAccountFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 3, "groups should be added to the account") @@ -728,7 +729,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { func TestAccountManager_GetAccountFromPAT(t *testing.T) { store := newStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) @@ -742,7 +743,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -751,7 +752,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { Store: store, } - account, user, pat, err := am.GetAccountFromPAT(token) + account, user, pat, err := am.GetAccountFromPAT(context.Background(), token) if err != nil { t.Fatalf("Error when getting Account from PAT: %s", err) } @@ -763,7 +764,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { store := newStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) @@ -778,7 +779,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -787,12 +788,12 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { Store: store, } - err = am.MarkPATUsed("tokenId") + err = am.MarkPATUsed(context.Background(), "tokenId") if err != nil { t.Fatalf("Error when marking PAT used: %s", err) } - account, err = am.Store.GetAccount("account_id") + account, err = am.Store.GetAccount(context.Background(), "account_id") if err != nil { t.Fatalf("Error when getting account: %s", err) } @@ -807,7 +808,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } userId := "test_user" - account, err := manager.GetOrCreateAccountByUser(userId, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "") if err != nil { t.Fatal(err) } @@ -815,7 +816,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { t.Fatalf("expected to create an account for a user %s", userId) } - account, err = manager.Store.GetAccountByUser(userId) + account, err = manager.Store.GetAccountByUser(context.Background(), userId) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) } @@ -834,7 +835,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { userId := "test_user" domain := "hotmail.com" - account, err := manager.GetOrCreateAccountByUser(userId, domain) + account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain) if err != nil { t.Fatal(err) } @@ -848,7 +849,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { domain = "gmail.com" - account, err = manager.GetOrCreateAccountByUser(userId, domain) + account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain) if err != nil { t.Fatalf("got the following error while retrieving existing acc: %v", err) } @@ -871,7 +872,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - account, err := manager.GetAccountByUserOrAccountID(userId, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "") if err != nil { t.Fatal(err) } @@ -880,20 +881,20 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { return } - _, err = manager.GetAccountByUserOrAccountID("", account.Id, "") + _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") if err != nil { t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) } - _, err = manager.GetAccountByUserOrAccountID("", "", "") + _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "") if err == nil { t.Errorf("expected an error when user and account IDs are empty") } } func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { - account := newAccountWithId(accountID, userID, domain) - err := am.Store.SaveAccount(account) + account := newAccountWithId(context.Background(), accountID, userID, domain) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } @@ -915,7 +916,7 @@ func TestAccountManager_GetAccount(t *testing.T) { } // AddAccount has been already tested so we can assume it is correct and compare results - getAccount, err := manager.Store.GetAccount(account.Id) + getAccount, err := manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return @@ -952,12 +953,12 @@ func TestAccountManager_DeleteAccount(t *testing.T) { t.Fatal(err) } - err = manager.DeleteAccount(account.Id, userId) + err = manager.DeleteAccount(context.Background(), account.Id, userId) if err != nil { t.Fatal(err) } - getAccount, err := manager.Store.GetAccount(account.Id) + getAccount, err := manager.Store.GetAccount(context.Background(), account.Id) if err == nil { t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) } @@ -978,7 +979,7 @@ func TestAccountManager_AddPeer(t *testing.T) { serial := account.Network.CurrentSerial() // should be 0 - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -997,7 +998,7 @@ func TestAccountManager_AddPeer(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedSetupKey := setupKey.Key - peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) @@ -1006,7 +1007,7 @@ func TestAccountManager_AddPeer(t *testing.T) { return } - account, err = manager.Store.GetAccount(account.Id) + account, err = manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return @@ -1045,7 +1046,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(userID, "netbird.cloud") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud") if err != nil { t.Fatal(err) } @@ -1065,7 +1066,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedUserID := userID - peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) @@ -1074,7 +1075,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { return } - account, err = manager.Store.GetAccount(account.Id) + account, err = manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return @@ -1121,7 +1122,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1140,7 +1141,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } expectedPeerKey := key.PublicKey().String() - peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) @@ -1156,14 +1157,14 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { peer2 := getPeer() peer3 := getPeer() - account, err = manager.Store.GetAccount(account.Id) + account, err = manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return } - updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID) - defer manager.peersUpdateManager.CloseChannel(peer1.ID) + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) group := group.Group{ ID: "group-id", @@ -1197,7 +1198,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.SaveGroup(account.Id, userID, &group); err != nil { + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1217,7 +1218,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.DeletePolicy(account.Id, account.Policies[0].ID, userID); err != nil { + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1237,7 +1238,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.SavePolicy(account.Id, userID, &policy); err != nil { + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1256,7 +1257,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil { + if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { t.Errorf("delete peer: %v", err) return } @@ -1277,9 +1278,9 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { }() // clean policy is pre requirement for delete group - _ = manager.DeletePolicy(account.Id, policy.ID, userID) + _ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID) - if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil { + if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { t.Errorf("delete group: %v", err) return } @@ -1301,7 +1302,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1315,7 +1316,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { peerKey := key.PublicKey().String() - peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey, Meta: nbpeer.PeerSystemMeta{Hostname: peerKey}, }) @@ -1324,12 +1325,12 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - err = manager.DeletePeer(account.Id, peerKey, userID) + err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID) if err != nil { return } - account, err = manager.Store.GetAccount(account.Id) + account, err = manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return @@ -1357,7 +1358,7 @@ func getEvent(t *testing.T, accountID string, manager AccountManager, eventType case <-time.After(time.Second): t.Fatal("no PeerAddedWithSetupKey event was generated") default: - events, err := manager.GetEvents(accountID, userID) + events, err := manager.GetEvents(context.Background(), accountID, userID) if err != nil { t.Fatal(err) } @@ -1389,7 +1390,7 @@ func TestGetUsersFromAccount(t *testing.T) { account.Users[user.Id] = user } - userInfos, err := manager.GetUsersFromAccount(accountId, "1") + userInfos, err := manager.GetUsersFromAccount(context.Background(), accountId, "1") if err != nil { t.Fatal(err) } @@ -1435,7 +1436,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { }, } - routes := account.GetRoutesByPrefix(prefix) + routes := account.GetRoutesByPrefixOrDomains(prefix, nil) assert.Len(t, routes, 2) routeIDs := make(map[route.ID]struct{}, 2) @@ -1500,7 +1501,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { }, } - routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) + routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) assert.Len(t, routes, 2) routeIDs := make(map[route.ID]struct{}, 2) @@ -1510,7 +1511,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-3")) - emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) + emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) assert.Len(t, emptyRoutes, 0) } @@ -1645,7 +1646,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") assert.NotNil(t, account.Settings) @@ -1657,23 +1658,23 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(userID, "", "") + _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1682,10 +1683,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(2) manager.peerLoginExpiry = &MockScheduler{ - CancelFunc: func(IDs []string) { + CancelFunc: func(ctx context.Context, IDs []string) { wg.Done() }, - ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wg.Done() }, } @@ -1693,11 +1694,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { // disable expiration first update := peer.Copy() update.LoginExpirationEnabled = false - _, err = manager.UpdatePeer(account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) require.NoError(t, err, "unable to update peer") // enabling expiration should trigger the routine update.LoginExpirationEnabled = true - _, err = manager.UpdatePeer(account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) require.NoError(t, err, "unable to update peer") failed := waitTimeout(wg, time.Second) @@ -1710,18 +1711,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1730,18 +1731,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. wg := &sync.WaitGroup{} wg.Add(2) manager.peerLoginExpiry = &MockScheduler{ - CancelFunc: func(IDs []string) { + CancelFunc: func(ctx context.Context, IDs []string) { wg.Done() }, - ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wg.Done() }, } - account, err = manager.GetAccountByUserOrAccountID(userID, "", "") + account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1754,35 +1755,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(userID, "", "") + _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} wg.Add(2) manager.peerLoginExpiry = &MockScheduler{ - CancelFunc: func(IDs []string) { + CancelFunc: func(ctx context.Context, IDs []string) { wg.Done() }, - ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wg.Done() }, } // enabling PeerLoginExpirationEnabled should trigger the expiration job - account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1795,7 +1796,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Add(1) // disabling PeerLoginExpirationEnabled should trigger cancel - _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -1810,10 +1811,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(account.Id, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -1821,19 +1822,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - account, err = manager.GetAccountByUserOrAccountID("", account.Id, "") + account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") require.NoError(t, err, "unable to get account by ID") assert.False(t, account.Settings.PeerLoginExpirationEnabled) assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour) - _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) @@ -2175,17 +2176,33 @@ func TestAccount_SetJWTGroups(t *testing.T) { }, } - t.Run("api group already exists", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1"}) + t.Run("empty jwt groups", func(t *testing.T) { + updated := account.SetJWTGroups("user1", []string{}) assert.False(t, updated, "account should not be updated") assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") }) + t.Run("jwt match existing api group", func(t *testing.T) { + updated := account.SetJWTGroups("user1", []string{"group1"}) + assert.False(t, updated, "account should not be updated") + assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) + assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + }) + + t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { + account.Users["user1"].AutoGroups = []string{"group1"} + + updated := account.SetJWTGroups("user1", []string{"group1"}) + assert.False(t, updated, "account should not be updated") + assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) + assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + }) + t.Run("add jwt group", func(t *testing.T) { updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) assert.True(t, updated, "account should be updated") assert.Len(t, account.Groups, 2, "new group should be added") - assert.Len(t, account.Users["user1"].AutoGroups, 1, "new group should be added") + assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") }) @@ -2278,7 +2295,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { } eventStore := &activity.InMemoryEventStore{} - manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) + manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) if err != nil { return nil, err } @@ -2289,7 +2306,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { func createStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(dataDir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) if err != nil { return nil, err } diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index b54db5276..fadf1eb07 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "encoding/json" "fmt" @@ -86,7 +87,7 @@ type Store struct { } // NewSQLiteStore creates a new Store with an event table if not exists. -func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { +func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) { dbFile := filepath.Join(dataDir, eventSinkDB) db, err := sql.Open("sqlite3", dbFile) if err != nil { @@ -111,7 +112,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { return nil, err } - err = updateDeletedUsersTable(db) + err = updateDeletedUsersTable(ctx, db) if err != nil { _ = db.Close() return nil, err @@ -153,7 +154,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { return s, nil } -func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { +func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) { events := make([]*activity.Event, 0) var cryptErr error for result.Next() { @@ -235,14 +236,14 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { } if cryptErr != nil { - log.Warnf("%s", cryptErr) + log.WithContext(ctx).Warnf("%s", cryptErr) } return events, nil } // Get returns "limit" number of events from index ordered descending or ascending by a timestamp -func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) { +func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) { stmt := store.selectDescStatement if !descending { stmt = store.selectAscStatement @@ -254,11 +255,11 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([ } defer result.Close() //nolint - return store.processResult(result) + return store.processResult(ctx, result) } // Save an event in the SQLite events table end encrypt the "email" element in meta map -func (store *Store) Save(event *activity.Event) (*activity.Event, error) { +func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) { var jsonMeta string meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event) if err != nil { @@ -317,15 +318,15 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event } // Close the Store -func (store *Store) Close() error { +func (store *Store) Close(_ context.Context) error { if store.db != nil { return store.db.Close() } return nil } -func updateDeletedUsersTable(db *sql.DB) error { - log.Debugf("check deleted_users table version") +func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error { + log.WithContext(ctx).Debugf("check deleted_users table version") rows, err := db.Query(`PRAGMA table_info(deleted_users);`) if err != nil { return err @@ -360,7 +361,7 @@ func updateDeletedUsersTable(db *sql.DB) error { return nil } - log.Debugf("update delted_users table") + log.WithContext(ctx).Debugf("update delted_users table") _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`) return err } diff --git a/management/server/activity/sqlite/sqlite_test.go b/management/server/activity/sqlite/sqlite_test.go index f6a6f9467..b10f9b58a 100644 --- a/management/server/activity/sqlite/sqlite_test.go +++ b/management/server/activity/sqlite/sqlite_test.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "fmt" "testing" "time" @@ -13,17 +14,17 @@ import ( func TestNewSQLiteStore(t *testing.T) { dataDir := t.TempDir() key, _ := GenerateKey() - store, err := NewSQLiteStore(dataDir, key) + store, err := NewSQLiteStore(context.Background(), dataDir, key) if err != nil { t.Fatal(err) return } - defer store.Close() //nolint + defer store.Close(context.Background()) //nolint accountID := "account_1" for i := 0; i < 10; i++ { - _, err = store.Save(&activity.Event{ + _, err = store.Save(context.Background(), &activity.Event{ Timestamp: time.Now().UTC(), Activity: activity.PeerAddedByUser, InitiatorID: "user_" + fmt.Sprint(i), @@ -36,7 +37,7 @@ func TestNewSQLiteStore(t *testing.T) { } } - result, err := store.Get(accountID, 0, 10, false) + result, err := store.Get(context.Background(), accountID, 0, 10, false) if err != nil { t.Fatal(err) return @@ -45,7 +46,7 @@ func TestNewSQLiteStore(t *testing.T) { assert.Len(t, result, 10) assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp)) - result, err = store.Get(accountID, 0, 5, true) + result, err = store.Get(context.Background(), accountID, 0, 5, true) if err != nil { t.Fatal(err) return diff --git a/management/server/activity/store.go b/management/server/activity/store.go index 77439e2e1..ef08e2b33 100644 --- a/management/server/activity/store.go +++ b/management/server/activity/store.go @@ -1,15 +1,18 @@ package activity -import "sync" +import ( + "context" + "sync" +) // Store provides an interface to store or stream events. type Store interface { // Save an event in the store - Save(event *Event) (*Event, error) + Save(ctx context.Context, event *Event) (*Event, error) // Get returns "limit" number of events from the "offset" index ordered descending or ascending by a timestamp - Get(accountID string, offset, limit int, descending bool) ([]*Event, error) + Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) // Close the sink flushing events if necessary - Close() error + Close(ctx context.Context) error } // InMemoryEventStore implements the Store interface storing data in-memory @@ -20,7 +23,7 @@ type InMemoryEventStore struct { } // Save sets the Event.ID to 1 -func (store *InMemoryEventStore) Save(event *Event) (*Event, error) { +func (store *InMemoryEventStore) Save(_ context.Context, event *Event) (*Event, error) { store.mu.Lock() defer store.mu.Unlock() if store.events == nil { @@ -33,7 +36,7 @@ func (store *InMemoryEventStore) Save(event *Event) (*Event, error) { } // Get returns a list of ALL events that belong to the given accountID without taking offset, limit and order into consideration -func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descending bool) ([]*Event, error) { +func (store *InMemoryEventStore) Get(_ context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) { store.mu.Lock() defer store.mu.Unlock() events := make([]*Event, 0) @@ -46,7 +49,7 @@ func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descen } // Close cleans up the event list -func (store *InMemoryEventStore) Close() error { +func (store *InMemoryEventStore) Close(_ context.Context) error { store.mu.Lock() defer store.mu.Unlock() store.events = make([]*Event, 0) diff --git a/management/server/context/keys.go b/management/server/context/keys.go new file mode 100644 index 000000000..c5b5da044 --- /dev/null +++ b/management/server/context/keys.go @@ -0,0 +1,8 @@ +package context + +const ( + RequestIDKey = "requestID" + AccountIDKey = "accountID" + UserIDKey = "userID" + PeerIDKey = "peerID" +) diff --git a/management/server/dns.go b/management/server/dns.go index 5e2febf55..8a889df3f 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "strconv" @@ -34,11 +35,11 @@ func (d DNSSettings) Copy() DNSSettings { } // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID -func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -56,11 +57,11 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) } // SaveDNSSettings validates a user role and updates the account's DNS settings -func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -89,7 +90,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string account.DNSSettings = dnsSettingsToSave.Copy() account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } @@ -97,17 +98,17 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string for _, id := range addedGroups { group := account.GetGroup(id) meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) } removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) for _, id := range removedGroups { group := account.GetGroup(id) meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } @@ -149,9 +150,9 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig { return protoUpdate } -func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone { +func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone { if dnsDomain == "" { - log.Errorf("no dns domain is set, returning empty zone") + log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone") return nbdns.CustomZone{} } @@ -161,7 +162,7 @@ func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone { for _, peer := range account.Peers { if peer.DNSLabel == "" { - log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name) + log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name) continue } @@ -210,14 +211,14 @@ func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { return false } -func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) { +func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) { for _, peer := range account.Peers { label, err := getPeerHostLabel(peer.Name, peerLabels) if err != nil { - log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) + log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels) if err != nil { - log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) + log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) continue } } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index bfa50b1cf..c6758036f 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "net/netip" "testing" @@ -35,7 +36,7 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("failed to init testing account") } - dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID) + dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) if err != nil { t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -48,12 +49,12 @@ func TestGetDNSSettings(t *testing.T) { DisabledManagementGroups: []string{group1ID}, } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("failed to save testing account with new DNS settings") } - dnsSettings, err = am.GetDNSSettings(account.Id, dnsAdminUserID) + dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) if err != nil { t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -62,7 +63,7 @@ func TestGetDNSSettings(t *testing.T) { t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) } - _, err = am.GetDNSSettings(account.Id, dnsRegularUserID) + _, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) if err == nil { t.Errorf("An error should be returned when getting the DNS settings with a regular user") } @@ -122,7 +123,7 @@ func TestSaveDNSSettings(t *testing.T) { t.Error("failed to init testing account") } - err = am.SaveDNSSettings(account.Id, testCase.userID, testCase.inputSettings) + err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) if err != nil { if testCase.shouldFail { return @@ -130,7 +131,7 @@ func TestSaveDNSSettings(t *testing.T) { t.Error(err) } - updatedAccount, err := am.Store.GetAccount(account.Id) + updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Errorf("should be able to retrieve updated account, got err: %s", err) } @@ -164,27 +165,27 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { t.Error("failed to init testing account") } - newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID) + newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) require.NoError(t, err) - require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS turnCfg should have one custom zone for peers") - require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS turnCfg should have local DNS service enabled") - require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS turnCfg should have no nameserver groups since peer 1 is NS for the only existing NS group") + require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers") + require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") + require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group") dnsSettings := account.DNSSettings.Copy() dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) account.DNSSettings = dnsSettings - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err) - updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID) + updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) require.NoError(t, err) - require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS turnCfg should have no custom zone when peer belongs to a disabled group") - require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS turnCfg should have local DNS service disabled when peer belongs to a disabled group") - peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID) + require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group") + require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group") + peer2AccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer2.ID) require.NoError(t, err) - require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS turnCfg should have one custom zone for peers not in the disabled group") - require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS turnCfg should have DNS service enabled for peers not in the disabled group") - require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS turnCfg should have 1 nameserver groups since peer 2 is part of the group All") + require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group") + require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group") + require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS config should have 1 nameserver groups since peer 2 is part of the group All") } func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { @@ -194,13 +195,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) } func createDNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(dataDir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) if err != nil { return nil, err } @@ -244,28 +245,28 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro domain := "example.com" - account := newAccountWithId(dnsAccountID, dnsAdminUserID, domain) + account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) account.Users[dnsRegularUserID] = &User{ Id: dnsRegularUserID, Role: UserRoleUser, } - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1) + savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) if err != nil { return nil, err } - _, _, err = am.AddPeer("", dnsAdminUserID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) if err != nil { return nil, err } - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(context.Background(), account.Id) if err != nil { return nil, err } @@ -312,10 +313,10 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro Groups: []string{allGroup.ID}, } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - return am.Store.GetAccount(account.Id) + return am.Store.GetAccount(context.Background(), account.Id) } diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 4fffa024d..590b1d708 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -1,6 +1,7 @@ package server import ( + "context" "sync" "time" @@ -51,13 +52,15 @@ func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralM // LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head // of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new // head. -func (e *EphemeralManager) LoadInitialPeers() { +func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) { e.peersLock.Lock() defer e.peersLock.Unlock() - e.loadEphemeralPeers() + e.loadEphemeralPeers(ctx) if e.headPeer != nil { - e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup) + e.timer = time.AfterFunc(ephemeralLifeTime, func() { + e.cleanup(ctx) + }) } } @@ -73,12 +76,12 @@ func (e *EphemeralManager) Stop() { // OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer // is active the manager will not delete it while it is active. -func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) { +func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) { if !peer.Ephemeral { return } - log.Tracef("remove peer from ephemeral list: %s", peer.ID) + log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID) e.peersLock.Lock() defer e.peersLock.Unlock() @@ -94,16 +97,16 @@ func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) { // OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer // is inactive it will be deleted after the ephemeralLifeTime period. -func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) { +func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) { if !peer.Ephemeral { return } - log.Tracef("add peer to ephemeral list: %s", peer.ID) + log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID) - a, err := e.store.GetAccountByPeerID(peer.ID) + a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID) if err != nil { - log.Errorf("failed to add peer to ephemeral list: %s", err) + log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err) return } @@ -116,12 +119,14 @@ func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) { e.addPeer(peer.ID, a, newDeadLine()) if e.timer == nil { - e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) + e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { + e.cleanup(ctx) + }) } } -func (e *EphemeralManager) loadEphemeralPeers() { - accounts := e.store.GetAllAccounts() +func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { + accounts := e.store.GetAllAccounts(context.Background()) t := newDeadLine() count := 0 for _, a := range accounts { @@ -132,10 +137,10 @@ func (e *EphemeralManager) loadEphemeralPeers() { } } } - log.Debugf("loaded ephemeral peer(s): %d", count) + log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) } -func (e *EphemeralManager) cleanup() { +func (e *EphemeralManager) cleanup(ctx context.Context) { log.Tracef("on ephemeral cleanup") deletePeers := make(map[string]*ephemeralPeer) @@ -154,7 +159,9 @@ func (e *EphemeralManager) cleanup() { } if e.headPeer != nil { - e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) + e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { + e.cleanup(ctx) + }) } else { e.timer = nil } @@ -162,10 +169,10 @@ func (e *EphemeralManager) cleanup() { e.peersLock.Unlock() for id, p := range deletePeers { - log.Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) + log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) + err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator) if err != nil { - log.Errorf("failed to delete ephemeral peer: %s", err) + log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) } } } diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 3e36335e3..36c88f1d1 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "testing" "time" @@ -13,11 +14,11 @@ type MockStore struct { account *Account } -func (s *MockStore) GetAllAccounts() []*Account { +func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { return []*Account{s.account} } -func (s *MockStore) GetAccountByPeerID(peerId string) (*Account, error) { +func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { _, ok := s.account.Peers[peerId] if ok { return s.account, nil @@ -31,7 +32,7 @@ type MocAccountManager struct { store *MockStore } -func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error { +func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { delete(a.store.account.Peers, peerID) return nil //nolint:nil } @@ -52,9 +53,9 @@ func TestNewManager(t *testing.T) { seedPeers(store, numberOfPeers, numberOfEphemeralPeers) mgr := NewEphemeralManager(store, am) - mgr.loadEphemeralPeers() + mgr.loadEphemeralPeers(context.Background()) startTime = startTime.Add(ephemeralLifeTime + 1) - mgr.cleanup() + mgr.cleanup(context.Background()) if len(store.account.Peers) != numberOfPeers { t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) @@ -77,11 +78,11 @@ func TestNewManagerPeerConnected(t *testing.T) { seedPeers(store, numberOfPeers, numberOfEphemeralPeers) mgr := NewEphemeralManager(store, am) - mgr.loadEphemeralPeers() - mgr.OnPeerConnected(store.account.Peers["ephemeral_peer_0"]) + mgr.loadEphemeralPeers(context.Background()) + mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) startTime = startTime.Add(ephemeralLifeTime + 1) - mgr.cleanup() + mgr.cleanup(context.Background()) expected := numberOfPeers + 1 if len(store.account.Peers) != expected { @@ -105,15 +106,15 @@ func TestNewManagerPeerDisconnected(t *testing.T) { seedPeers(store, numberOfPeers, numberOfEphemeralPeers) mgr := NewEphemeralManager(store, am) - mgr.loadEphemeralPeers() + mgr.loadEphemeralPeers(context.Background()) for _, v := range store.account.Peers { - mgr.OnPeerConnected(v) + mgr.OnPeerConnected(context.Background(), v) } - mgr.OnPeerDisconnected(store.account.Peers["ephemeral_peer_0"]) + mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) startTime = startTime.Add(ephemeralLifeTime + 1) - mgr.cleanup() + mgr.cleanup(context.Background()) expected := numberOfPeers + numberOfEphemeralPeers - 1 if len(store.account.Peers) != expected { @@ -122,7 +123,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { - store.account = newAccountWithId("my account", "", "") + store.account = newAccountWithId(context.Background(), "my account", "", "") for i := 0; i < numberOfPeers; i++ { peerId := fmt.Sprintf("peer_%d", i) diff --git a/management/server/event.go b/management/server/event.go index 303f88a79..616cea287 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "time" @@ -11,11 +12,11 @@ import ( ) // GetEvents returns a list of activity events of an account -func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -29,7 +30,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events") } - events, err := am.eventStore.Get(accountID, 0, 10000, true) + events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true) if err != nil { return nil, err } @@ -54,10 +55,10 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit return filtered, nil } -func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { +func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { go func() { - _, err := am.eventStore.Save(&activity.Event{ + _, err := am.eventStore.Save(ctx, &activity.Event{ Timestamp: time.Now().UTC(), Activity: activityID, InitiatorID: initiatorID, @@ -67,7 +68,7 @@ func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID str }) if err != nil { // todo add metric - log.Errorf("received an error while storing an activity event, error: %s", err) + log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err) } }() diff --git a/management/server/event_test.go b/management/server/event_test.go index 401c80759..8c56fd3f6 100644 --- a/management/server/event_test.go +++ b/management/server/event_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "testing" "time" @@ -13,7 +14,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac accountID string, count int) { t.Helper() for i := 0; i < count; i++ { - _, err := manager.eventStore.Save(&activity.Event{ + _, err := manager.eventStore.Save(context.Background(), &activity.Event{ Timestamp: time.Now().UTC(), Activity: typ, InitiatorID: initiatorID, @@ -35,32 +36,32 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) { accountID := "accountID" t.Run("get empty events list", func(t *testing.T) { - events, err := manager.GetEvents(accountID, userID) + events, err := manager.GetEvents(context.Background(), accountID, userID) if err != nil { return } assert.Len(t, events, 0) - _ = manager.eventStore.Close() //nolint + _ = manager.eventStore.Close(context.Background()) //nolint }) t.Run("get events", func(t *testing.T) { generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10) - events, err := manager.GetEvents(accountID, userID) + events, err := manager.GetEvents(context.Background(), accountID, userID) if err != nil { return } assert.Len(t, events, 10) - _ = manager.eventStore.Close() //nolint + _ = manager.eventStore.Close(context.Background()) //nolint }) t.Run("get events without duplicates", func(t *testing.T) { generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10) - events, err := manager.GetEvents(accountID, userID) + events, err := manager.GetEvents(context.Background(), accountID, userID) if err != nil { return } assert.Len(t, events, 1) - _ = manager.eventStore.Close() //nolint + _ = manager.eventStore.Close(context.Background()) //nolint }) } diff --git a/management/server/file_store.go b/management/server/file_store.go index 60497824c..3fd543797 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1,6 +1,7 @@ package server import ( + "context" "os" "path/filepath" "strings" @@ -48,8 +49,8 @@ type FileStore struct { type StoredAccount struct{} // NewFileStore restores a store from the file located in the datadir -func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { - fs, err := restore(filepath.Join(dataDir, storeFileName)) +func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { + fs, err := restore(ctx, filepath.Join(dataDir, storeFileName)) if err != nil { return nil, err } @@ -58,27 +59,27 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err } // NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir -func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { - store, err := NewFileStore(dataDir, metrics) +func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { + store, err := NewFileStore(ctx, dataDir, metrics) if err != nil { return nil, err } - err = store.SaveInstallationID(sqlStore.GetInstallationID()) + err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID()) if err != nil { return nil, err } - for _, account := range sqlStore.GetAllAccounts() { + for _, account := range sqlStore.GetAllAccounts(ctx) { store.Accounts[account.Id] = account } - return store, store.persist(store.storeFile) + return store, store.persist(ctx, store.storeFile) } // restore the state of the store from the file. // Creates a new empty store file if doesn't exist -func restore(file string) (*FileStore, error) { +func restore(ctx context.Context, file string) (*FileStore, error) { if _, err := os.Stat(file); os.IsNotExist(err) { // create a new FileStore if previously didn't exist (e.g. first run) s := &FileStore{ @@ -95,7 +96,7 @@ func restore(file string) (*FileStore, error) { storeFile: file, } - err = s.persist(file) + err = s.persist(ctx, file) if err != nil { return nil, err } @@ -165,7 +166,7 @@ func restore(file string) (*FileStore, error) { // for data migration. Can be removed once most base will be with labels existingLabels := account.getPeerDNSLabels() if len(existingLabels) != len(account.Peers) { - addPeerLabelsToAccount(account, existingLabels) + addPeerLabelsToAccount(ctx, account, existingLabels) } // TODO: delete this block after migration @@ -178,7 +179,7 @@ func restore(file string) (*FileStore, error) { allGroup, err := account.GetGroupAll() if err != nil { - log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) + log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) // if the All group didn't exist we probably don't have routes to update continue } @@ -236,7 +237,7 @@ func restore(file string) (*FileStore, error) { } // we need this persist to apply changes we made to account.Peers (we set them to Disconnected) - err = store.persist(store.storeFile) + err = store.persist(ctx, store.storeFile) if err != nil { return nil, err } @@ -246,7 +247,7 @@ func restore(file string) (*FileStore, error) { // persist account data to a file // It is recommended to call it with locking FileStore.mux -func (s *FileStore) persist(file string) error { +func (s *FileStore) persist(ctx context.Context, file string) error { start := time.Now() err := util.WriteJson(file, s) if err != nil { @@ -256,23 +257,23 @@ func (s *FileStore) persist(file string) error { if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) } - log.Debugf("took %d ms to persist the FileStore", took.Milliseconds()) + log.WithContext(ctx).Debugf("took %d ms to persist the FileStore", took.Milliseconds()) return nil } // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock -func (s *FileStore) AcquireGlobalLock() (unlock func()) { - log.Debugf("acquiring global lock") +func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { + log.WithContext(ctx).Debugf("acquiring global lock") start := time.Now() s.globalAccountLock.Lock() unlock = func() { s.globalAccountLock.Unlock() - log.Debugf("released global lock in %v", time.Since(start)) + log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start)) } took := time.Since(start) - log.Debugf("took %v to acquire global lock", took) + log.WithContext(ctx).Debugf("took %v to acquire global lock", took) if s.metrics != nil { s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) } @@ -281,8 +282,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) { } // AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock -func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) { - log.Debugf("acquiring lock for account %s", accountID) +func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) { + log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) mtx := value.(*sync.Mutex) @@ -290,7 +291,7 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) { unlock = func() { mtx.Unlock() - log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start)) } return unlock @@ -298,11 +299,11 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) { // AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock // This method is still returns a write lock as file store can't handle read locks -func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) { - return s.AcquireAccountWriteLock(accountID) +func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) { + return s.AcquireAccountWriteLock(ctx, accountID) } -func (s *FileStore) SaveAccount(account *Account) error { +func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error { s.mux.Lock() defer s.mux.Unlock() @@ -338,10 +339,10 @@ func (s *FileStore) SaveAccount(account *Account) error { s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id } - return s.persist(s.storeFile) + return s.persist(ctx, s.storeFile) } -func (s *FileStore) DeleteAccount(account *Account) error { +func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error { s.mux.Lock() defer s.mux.Unlock() @@ -373,7 +374,7 @@ func (s *FileStore) DeleteAccount(account *Account) error { delete(s.Accounts, account.Id) - return s.persist(s.storeFile) + return s.persist(ctx, s.storeFile) } // DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID @@ -397,7 +398,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error { } // GetAccountByPrivateDomain returns account by private domain -func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { +func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -415,7 +416,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { } // GetAccountBySetupKey returns account by setup key id -func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { +func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -433,7 +434,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { } // GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret -func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) { +func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) { s.mux.Lock() defer s.mux.Unlock() @@ -446,7 +447,7 @@ func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) { } // GetUserByTokenID returns a User object a tokenID belongs to -func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) { +func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) { s.mux.Lock() defer s.mux.Unlock() @@ -469,7 +470,7 @@ func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) { } // GetAllAccounts returns all accounts -func (s *FileStore) GetAllAccounts() (all []*Account) { +func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { s.mux.Lock() defer s.mux.Unlock() for _, a := range s.Accounts { @@ -490,7 +491,7 @@ func (s *FileStore) getAccount(accountID string) (*Account, error) { } // GetAccount returns an account for ID -func (s *FileStore) GetAccount(accountID string) (*Account, error) { +func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -503,7 +504,7 @@ func (s *FileStore) GetAccount(accountID string) (*Account, error) { } // GetAccountByUser returns a user account -func (s *FileStore) GetAccountByUser(userID string) (*Account, error) { +func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -521,7 +522,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) { } // GetAccountByPeerID returns an account for a given peer ID -func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) { +func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -539,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) { // check Account.Peers for a match if _, ok := account.Peers[peerID]; !ok { delete(s.PeerID2AccountID, peerID) - log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) + log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) return nil, status.NewPeerNotFoundError(peerID) } @@ -547,7 +548,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) { } // GetAccountByPeerPubKey returns an account for a given peer WireGuard public key -func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { +func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -572,14 +573,14 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { } if stale { delete(s.PeerKeyID2AccountID, peerKey) - log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) + log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) return nil, status.NewPeerNotFoundError(peerKey) } return account.Copy(), nil } -func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { +func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) { s.mux.Lock() defer s.mux.Unlock() @@ -603,7 +604,7 @@ func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) { return accountID, nil } -func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) { +func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) { s.mux.Lock() defer s.mux.Unlock() @@ -615,7 +616,7 @@ func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) { return accountID, nil } -func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) { +func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) { s.mux.Lock() defer s.mux.Unlock() @@ -638,7 +639,7 @@ func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) { return nil, status.NewPeerNotFoundError(peerKey) } -func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) { +func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) { s.mux.Lock() defer s.mux.Unlock() @@ -656,13 +657,13 @@ func (s *FileStore) GetInstallationID() string { } // SaveInstallationID saves the installation ID -func (s *FileStore) SaveInstallationID(ID string) error { +func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error { s.mux.Lock() defer s.mux.Unlock() s.InstallationID = ID - return s.persist(s.storeFile) + return s.persist(ctx, s.storeFile) } // SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. @@ -732,13 +733,13 @@ func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks * } // Close the FileStore persisting data to disk -func (s *FileStore) Close() error { +func (s *FileStore) Close(ctx context.Context) error { s.mux.Lock() defer s.mux.Unlock() - log.Infof("closing FileStore") + log.WithContext(ctx).Infof("closing FileStore") - return s.persist(s.storeFile) + return s.persist(ctx, s.storeFile) } // GetStoreEngine returns FileStoreEngine diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index 11571b0be..56e46b696 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/sha256" "net" "path/filepath" @@ -27,12 +28,12 @@ func TestStalePeerIndices(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) peerID := "some_peer" @@ -42,24 +43,24 @@ func TestStalePeerIndices(t *testing.T) { Key: peerKey, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account.DeletePeer(peerID) - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - _, err = store.GetAccountByPeerID(peerID) + _, err = store.GetAccountByPeerID(context.Background(), peerID) require.Error(t, err, "expecting to get an error when found stale index") - _, err = store.GetAccountByPeerPubKey(peerKey) + _, err = store.GetAccountByPeerPubKey(context.Background(), peerKey) require.Error(t, err, "expecting to get an error when found stale index") } func TestNewStore(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) if store.Accounts == nil || len(store.Accounts) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") @@ -88,9 +89,9 @@ func TestNewStore(t *testing.T) { func TestSaveAccount(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -103,7 +104,7 @@ func TestSaveAccount(t *testing.T) { } // SaveAccount should trigger persist - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { return } @@ -133,11 +134,11 @@ func TestDeleteAccount(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } - defer store.Close() + defer store.Close(context.Background()) var account *Account for _, a := range store.Accounts { @@ -147,7 +148,7 @@ func TestDeleteAccount(t *testing.T) { require.NotNil(t, account, "failed to restore a FileStore file and get at least one account") - err = store.DeleteAccount(account) + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err, "failed to delete account, error: %v", err) _, ok := store.Accounts[account.Id] @@ -183,9 +184,9 @@ func TestDeleteAccount(t *testing.T) { func TestStore(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", SetupKey: "peerkeysetupkey", @@ -228,12 +229,12 @@ func TestStore(t *testing.T) { }) // SaveAccount should trigger persist - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { return } - restored, err := NewFileStore(store.storeFile, nil) + restored, err := NewFileStore(context.Background(), store.storeFile, nil) if err != nil { return } @@ -281,7 +282,7 @@ func TestRestore(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } @@ -319,7 +320,7 @@ func TestRestoreGroups_Migration(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } @@ -332,11 +333,11 @@ func TestRestoreGroups_Migration(t *testing.T) { Name: "All", }, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to save account") // restore account with default group with empty Issue field - if store, err = NewFileStore(storeDir, nil); err != nil { + if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil { return } account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] @@ -353,18 +354,18 @@ func TestGetAccountByPrivateDomain(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } existingDomain := "test.com" - account, err := store.GetAccountByPrivateDomain(existingDomain) + account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) require.NoError(t, err, "should found account") require.Equal(t, existingDomain, account.Domain, "domains should match") - _, err = store.GetAccountByPrivateDomain("missing-domain.com") + _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") require.Error(t, err, "should return error on domain lookup") } @@ -382,7 +383,7 @@ func TestFileStore_GetAccount(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } @@ -393,7 +394,7 @@ func TestFileStore_GetAccount(t *testing.T) { return } - account, err := store.GetAccount(expected.Id) + account, err := store.GetAccount(context.Background(), expected.Id) if err != nil { t.Fatal(err) } @@ -424,13 +425,13 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken - tokenID, err := store.GetTokenIDByHashedToken(hashedToken) + tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken) if err != nil { t.Fatal(err) } @@ -441,7 +442,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) { func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) store.HashedPAT2TokenID["someHashedToken"] = "someTokenId" err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken") @@ -478,13 +479,13 @@ func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234")) - _, err = store.GetTokenIDByHashedToken(string(wrongToken[:])) + _, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:])) assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid") } @@ -503,13 +504,13 @@ func TestFileStore_GetUserByTokenID(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - user, err := store.GetUserByTokenID(tokenID) + user, err := store.GetUserByTokenID(context.Background(), tokenID) if err != nil { t.Fatal(err) } @@ -531,13 +532,13 @@ func TestFileStore_GetUserByTokenID_Failure(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } wrongTokenID := "someNonExistingTokenID" - _, err = store.GetUserByTokenID(wrongTokenID) + _, err = store.GetUserByTokenID(context.Background(), wrongTokenID) assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid") } @@ -550,7 +551,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } @@ -576,7 +577,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) } @@ -602,11 +603,11 @@ func TestFileStore_SavePeerLocation(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) peer := &nbpeer.Peer{ @@ -625,7 +626,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) { assert.Error(t, err) account.Peers[peer.ID] = peer - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) peer.Location.ConnectionIP = net.ParseIP("35.1.1.1") @@ -636,7 +637,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) { err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) assert.NoError(t, err) - account, err = store.GetAccount(account.Id) + account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers[peer.ID].Location @@ -645,7 +646,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) { func newStore(t *testing.T) *FileStore { t.Helper() - store, err := NewFileStore(t.TempDir(), nil) + store, err := NewFileStore(context.Background(), t.TempDir(), nil) if err != nil { t.Errorf("failed creating a new store") } diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index 4fd28806b..794f9d0be 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -2,6 +2,7 @@ package geolocation import ( "bytes" + "context" "fmt" "net" "os" @@ -52,7 +53,7 @@ type Country struct { CountryName string } -func NewGeolocation(dataDir string) (*Geolocation, error) { +func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) { if err := loadGeolocationDatabases(dataDir); err != nil { return nil, fmt.Errorf("failed to load MaxMind databases: %v", err) } @@ -68,7 +69,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) { return nil, err } - locationDB, err := NewSqliteStore(dataDir) + locationDB, err := NewSqliteStore(ctx, dataDir) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) { stopCh: make(chan struct{}), } - go geo.reloader() + go geo.reloader(ctx) return geo, nil } @@ -165,19 +166,19 @@ func (gl *Geolocation) Stop() error { return nil } -func (gl *Geolocation) reloader() { +func (gl *Geolocation) reloader(ctx context.Context) { for { select { case <-gl.stopCh: return case <-time.After(gl.reloadCheckInterval): - if err := gl.locationDB.reload(); err != nil { - log.Errorf("geonames db reload failed: %s", err) + if err := gl.locationDB.reload(ctx); err != nil { + log.WithContext(ctx).Errorf("geonames db reload failed: %s", err) } newSha256sum1, err := calculateFileSHA256(gl.mmdbPath) if err != nil { - log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) + log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) continue } if !bytes.Equal(gl.sha256sum, newSha256sum1) { @@ -186,30 +187,30 @@ func (gl *Geolocation) reloader() { time.Sleep(50 * time.Millisecond) newSha256sum2, err := calculateFileSHA256(gl.mmdbPath) if err != nil { - log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) + log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) continue } if !bytes.Equal(newSha256sum1, newSha256sum2) { - log.Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath) + log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath) continue } - err = gl.reload(newSha256sum2) + err = gl.reload(ctx, newSha256sum2) if err != nil { - log.Errorf("mmdb reload failed: %s", err) + log.WithContext(ctx).Errorf("mmdb reload failed: %s", err) } } else { - log.Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.", + log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.", gl.mmdbPath, gl.reloadCheckInterval.Seconds()) } } } } -func (gl *Geolocation) reload(newSha256sum []byte) error { +func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error { gl.mux.Lock() defer gl.mux.Unlock() - log.Infof("Reloading '%s'", gl.mmdbPath) + log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath) err := gl.db.Close() if err != nil { @@ -224,7 +225,7 @@ func (gl *Geolocation) reload(newSha256sum []byte) error { gl.db = db gl.sha256sum = newSha256sum - log.Infof("Successfully reloaded '%s'", gl.mmdbPath) + log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath) return nil } diff --git a/management/server/geolocation/store.go b/management/server/geolocation/store.go index 3da7989e1..67d420cfd 100644 --- a/management/server/geolocation/store.go +++ b/management/server/geolocation/store.go @@ -2,6 +2,7 @@ package geolocation import ( "bytes" + "context" "fmt" "path/filepath" "runtime" @@ -50,10 +51,10 @@ type SqliteStore struct { sha256sum []byte } -func NewSqliteStore(dataDir string) (*SqliteStore, error) { +func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) { file := filepath.Join(dataDir, GeoSqliteDBFile) - db, err := connectDB(file) + db, err := connectDB(ctx, file) if err != nil { return nil, err } @@ -115,13 +116,13 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error) } // reload attempts to reload the SqliteStore's database if the database file has changed. -func (s *SqliteStore) reload() error { +func (s *SqliteStore) reload(ctx context.Context) error { s.mux.Lock() defer s.mux.Unlock() newSha256sum1, err := calculateFileSHA256(s.filePath) if err != nil { - log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) + log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) } if !bytes.Equal(s.sha256sum, newSha256sum1) { @@ -136,11 +137,11 @@ func (s *SqliteStore) reload() error { return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath) } - log.Infof("Reloading '%s'", s.filePath) + log.WithContext(ctx).Infof("Reloading '%s'", s.filePath) _ = s.close() s.closed = true - newDb, err := connectDB(s.filePath) + newDb, err := connectDB(ctx, s.filePath) if err != nil { return err } @@ -148,9 +149,9 @@ func (s *SqliteStore) reload() error { s.closed = false s.db = newDb - log.Infof("Successfully reloaded '%s'", s.filePath) + log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath) } else { - log.Tracef("No changes in '%s', no need to reload", s.filePath) + log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath) } return nil @@ -168,10 +169,10 @@ func (s *SqliteStore) close() error { } // connectDB connects to an SQLite database and prepares it by setting up an in-memory database. -func connectDB(filePath string) (*gorm.DB, error) { +func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) { start := time.Now() defer func() { - log.Debugf("took %v to setup geoname db", time.Since(start)) + log.WithContext(ctx).Debugf("took %v to setup geoname db", time.Since(start)) }() _, err := fileExists(filePath) diff --git a/management/server/group.go b/management/server/group.go index 7ede2120d..ea512924b 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "github.com/rs/xid" @@ -21,11 +22,11 @@ func (e *GroupLinkError) Error() string { } // GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -48,11 +49,11 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -75,11 +76,11 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ( } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -108,11 +109,11 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -150,11 +151,11 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n account.Groups[newGroup.ID] = newGroup account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) // the following snippet tracks the activity and stores the group events in the event store. // It has to happen after all the operations have been successfully performed. @@ -165,16 +166,16 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { addedPeers = append(addedPeers, newGroup.Peers...) - am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) + am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) } for _, p := range addedPeers { peer := account.Peers[p] if peer == nil { - log.Errorf("peer %s not found under account %s while saving group", p, accountID) + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) continue } - am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer, + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), @@ -184,10 +185,10 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n for _, p := range removedPeers { peer := account.Peers[p] if peer == nil { - log.Errorf("peer %s not found under account %s while saving group", p, accountID) + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) continue } - am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer, + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), @@ -213,11 +214,11 @@ func difference(a, b []string) []string { } // DeleteGroup object of the peers -func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountId) +func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountId) defer unlock() - account, err := am.Store.GetAccount(accountId) + account, err := am.Store.GetAccount(ctx, accountId) if err != nil { return err } @@ -315,23 +316,23 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) delete(account.Groups, groupID) account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) + am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -345,11 +346,11 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, } // GroupAddPeer appends peer to the group -func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -371,21 +372,21 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) } account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // GroupDeletePeer removes peer from the group -func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -399,13 +400,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID stri for i, itemID := range group.Peers { if itemID == peerID { group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - if err := am.Store.SaveAccount(account); err != nil { + if err := am.Store.SaveAccount(ctx, account); err != nil { return err } } } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } diff --git a/management/server/group_test.go b/management/server/group_test.go index 1c718715d..373d72964 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "errors" "testing" @@ -26,7 +27,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { } for _, group := range account.Groups { group.Issued = nbgroup.GroupIssuedIntegration - err = am.SaveGroup(account.Id, groupAdminUserID, group) + err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) } @@ -34,7 +35,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = nbgroup.GroupIssuedJWT - err = am.SaveGroup(account.Id, groupAdminUserID, group) + err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) } @@ -42,7 +43,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = nbgroup.GroupIssuedAPI group.ID = "" - err = am.SaveGroup(account.Id, groupAdminUserID, group) + err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err == nil { t.Errorf("should not create api group with the same name, %s", group.Name) } @@ -104,7 +105,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - err = am.DeleteGroup(account.Id, groupAdminUserID, testCase.groupID) + err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, testCase.groupID) if err == nil { t.Errorf("delete %s group successfully", testCase.groupID) return @@ -225,7 +226,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { Id: "example user", AutoGroups: []string{groupForUsers.ID}, } - account := newAccountWithId(accountID, groupAdminUserID, domain) + account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) account.Routes[routeResource.ID] = routeResource account.Routes[routePeerGroupResource.ID] = routePeerGroupResource account.NameServerGroups[nameServerGroup.ID] = nameServerGroup @@ -233,18 +234,18 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { account.SetupKeys[setupKey.Id] = setupKey account.Users[user.Id] = user - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - _ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForIntegration) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) - return am.Store.GetAccount(account.Id) + return am.Store.GetAccount(context.Background(), account.Id) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index a7d3b675d..3abcd1ccd 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -18,8 +18,10 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" internalStatus "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -39,7 +41,7 @@ type GRPCServer struct { } // NewServer creates a new Management server -func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnRelayTokenManager TURNRelayTokenManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) { +func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnRelayTokenManager TURNRelayTokenManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err @@ -49,6 +51,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { jwtValidator, err = jwtclaims.NewJWTValidator( + ctx, config.HttpConfig.AuthIssuer, config.GetAuthAudiences(), config.HttpConfig.AuthKeysLocation, @@ -58,7 +61,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) } } else { - log.Debug("unable to use http turnCfg to create new jwt middleware") + log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware") } if appMetrics != nil { @@ -125,104 +128,134 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequest() } - realIP := getRealIP(srv.Context()) - log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + + ctx := srv.Context() + + realIP := getRealIP(ctx) syncReq := &proto.SyncRequest{} - peerKey, err := s.parseRequest(req, syncReq) + peerKey, err := s.parseRequest(ctx, req, syncReq) if err != nil { return err } - peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP) + //nolint + ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { - return mapError(err) + // this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail + accountID = "UNKNOWN" + } + //nolint + ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + + if syncReq.GetMeta() == nil { + log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - err = s.sendInitialSync(peerKey, peer, netMap, srv) + peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) if err != nil { - log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) + return mapError(ctx, err) + } + + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) + if err != nil { + log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) return err } - updates := s.peersUpdateManager.CreateChannel(peer.ID) + updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) - s.ephemeralManager.OnPeerConnected(peer) + s.ephemeralManager.OnPeerConnected(ctx, peer) if s.config.TURNConfig.TimeBasedCredentials { - s.turnRelayTokenManager.SetupRefresh(peer.ID) + s.turnRelayTokenManager.SetupRefresh(ctx, peer.ID) } if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) } - // keep a connection to the peer and send updates when available + return s.handleUpdates(ctx, peerKey, peer, updates, srv) +} + +// handleUpdates sends updates to the connected peer until the updates channel is closed. +func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { for { select { // condition when there are some updates case update, open := <-updates: - if s.appMetrics != nil { s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1) } if !open { - log.Debugf("updates channel for peer %s was closed", peerKey.String()) - s.cancelPeerRoutines(peer) + log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) + s.cancelPeerRoutines(ctx, peer) return nil } - log.Debugf("received an update for peer %s", peerKey.String()) + log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) - if err != nil { - s.cancelPeerRoutines(peer) - return status.Errorf(codes.Internal, "failed processing update message") + if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil { + return err } - err = srv.SendMsg(&proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), - Body: encryptedResp, - }) - if err != nil { - s.cancelPeerRoutines(peer) - return status.Errorf(codes.Internal, "failed sending update message") - } - log.Debugf("sent an update to peer %s", peerKey.String()) // condition when client <-> server connection has been terminated case <-srv.Context().Done(): // happens when connection drops, e.g. client disconnects - log.Debugf("stream of peer %s has been closed", peerKey.String()) - s.cancelPeerRoutines(peer) + log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) + s.cancelPeerRoutines(ctx, peer) return srv.Context().Err() } } } -func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) { - s.peersUpdateManager.CloseChannel(peer.ID) - s.turnRelayTokenManager.CancelRefresh(peer.ID) - _ = s.accountManager.CancelPeerRoutines(peer) - s.ephemeralManager.OnPeerDisconnected(peer) +// sendUpdate encrypts the update message using the peer key and the server's wireguard key, +// then sends the encrypted message to the connected peer via the sync server. +func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) + if err != nil { + s.cancelPeerRoutines(ctx, peer) + return status.Errorf(codes.Internal, "failed processing update message") + } + err = srv.SendMsg(&proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }) + if err != nil { + s.cancelPeerRoutines(ctx, peer) + return status.Errorf(codes.Internal, "failed sending update message") + } + log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) + return nil } -func (s *GRPCServer) validateToken(jwtToken string) (string, error) { +func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) { + s.peersUpdateManager.CloseChannel(ctx, peer.ID) + s.turnRelayTokenManager.CancelRefresh(peer.ID) + _ = s.accountManager.CancelPeerRoutines(ctx, peer) + s.ephemeralManager.OnPeerDisconnected(ctx, peer) +} + +func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { if s.jwtValidator == nil { return "", status.Error(codes.Internal, "no jwt validator set") } - token, err := s.jwtValidator.ValidateAndParse(jwtToken) + token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken) if err != nil { return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } claims := s.jwtClaimsExtractor.FromToken(token) // we need to call this method because if user is new, we will automatically add it to existing or create a new account - _, _, err = s.accountManager.GetAccountFromToken(claims) + _, _, err = s.accountManager.GetAccountFromToken(ctx, claims) if err != nil { return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } - if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil { + if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil { return "", status.Errorf(codes.PermissionDenied, err.Error()) } @@ -230,7 +263,7 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) { } // maps internal internalStatus.Error to gRPC status.Error -func mapError(err error) error { +func mapError(ctx context.Context, err error) error { if e, ok := internalStatus.FromError(err); ok { switch e.Type() { case internalStatus.PermissionDenied: @@ -246,21 +279,25 @@ func mapError(err error) error { default: } } - log.Errorf("got an unhandled error: %s", err) + log.WithContext(ctx).Errorf("got an unhandled error: %s", err) return status.Errorf(codes.Internal, "failed handling request") } -func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta { - osVersion := loginReq.GetMeta().GetOSVersion() - if osVersion == "" { - osVersion = loginReq.GetMeta().GetCore() +func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta { + if meta == nil { + return nbpeer.PeerSystemMeta{} } - networkAddresses := make([]nbpeer.NetworkAddress, 0, len(loginReq.GetMeta().GetNetworkAddresses())) - for _, addr := range loginReq.GetMeta().GetNetworkAddresses() { + osVersion := meta.GetOSVersion() + if osVersion == "" { + osVersion = meta.GetCore() + } + + networkAddresses := make([]nbpeer.NetworkAddress, 0, len(meta.GetNetworkAddresses())) + for _, addr := range meta.GetNetworkAddresses() { netAddr, err := netip.ParsePrefix(addr.GetNetIP()) if err != nil { - log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err) + log.WithContext(ctx).Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err) continue } networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{ @@ -269,31 +306,41 @@ func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta { }) } + files := make([]nbpeer.File, 0, len(meta.GetFiles())) + for _, file := range meta.GetFiles() { + files = append(files, nbpeer.File{ + Path: file.GetPath(), + Exist: file.GetExist(), + ProcessIsRunning: file.GetProcessIsRunning(), + }) + } + return nbpeer.PeerSystemMeta{ - Hostname: loginReq.GetMeta().GetHostname(), - GoOS: loginReq.GetMeta().GetGoOS(), - Kernel: loginReq.GetMeta().GetKernel(), - Platform: loginReq.GetMeta().GetPlatform(), - OS: loginReq.GetMeta().GetOS(), + Hostname: meta.GetHostname(), + GoOS: meta.GetGoOS(), + Kernel: meta.GetKernel(), + Platform: meta.GetPlatform(), + OS: meta.GetOS(), OSVersion: osVersion, - WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(), - UIVersion: loginReq.GetMeta().GetUiVersion(), - KernelVersion: loginReq.GetMeta().GetKernelVersion(), + WtVersion: meta.GetWiretrusteeVersion(), + UIVersion: meta.GetUiVersion(), + KernelVersion: meta.GetKernelVersion(), NetworkAddresses: networkAddresses, - SystemSerialNumber: loginReq.GetMeta().GetSysSerialNumber(), - SystemProductName: loginReq.GetMeta().GetSysProductName(), - SystemManufacturer: loginReq.GetMeta().GetSysManufacturer(), + SystemSerialNumber: meta.GetSysSerialNumber(), + SystemProductName: meta.GetSysProductName(), + SystemManufacturer: meta.GetSysManufacturer(), Environment: nbpeer.Environment{ - Cloud: loginReq.GetMeta().GetEnvironment().GetCloud(), - Platform: loginReq.GetMeta().GetEnvironment().GetPlatform(), + Cloud: meta.GetEnvironment().GetCloud(), + Platform: meta.GetEnvironment().GetPlatform(), }, + Files: files, } } -func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { +func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { - log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) + log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) } @@ -320,61 +367,57 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p s.appMetrics.GRPCMetrics().CountLoginRequest() } realIP := getRealIP(ctx) - log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) loginReq := &proto.LoginRequest{} - peerKey, err := s.parseRequest(req, loginReq) + peerKey, err := s.parseRequest(ctx, req, loginReq) if err != nil { return nil, err } + //nolint + ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) + if err != nil { + // this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail + accountID = "UNKNOWN" + } + //nolint + ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + if loginReq.GetMeta() == nil { msg := status.Errorf(codes.FailedPrecondition, "peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP) - log.Warn(msg) + log.WithContext(ctx).Warn(msg) return nil, msg } - userID := "" - // JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered, - // or it uses a setup key to register. - - if loginReq.GetJwtToken() != "" { - for i := 0; i < 3; i++ { - userID, err = s.validateToken(loginReq.GetJwtToken()) - if err == nil { - break - } - log.Warnf("failed validating JWT token sent from peer %s with error %v. "+ - "Trying again as it may be due to the IdP cache issue", peerKey, err) - time.Sleep(200 * time.Millisecond) - } - if err != nil { - return nil, err - } + userID, err := s.processJwtToken(ctx, loginReq, peerKey) + if err != nil { + return nil, err } + var sshKey []byte if loginReq.GetPeerKeys() != nil { sshKey = loginReq.GetPeerKeys().GetSshPubKey() } - peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{ + peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, PeerLogin{ WireGuardPubKey: peerKey.String(), SSHKey: string(sshKey), - Meta: extractPeerMeta(loginReq), + Meta: extractPeerMeta(ctx, loginReq.GetMeta()), UserID: userID, SetupKey: loginReq.GetSetupKey(), ConnectionIP: realIP, }) - if err != nil { - log.Warnf("failed logging in peer %s: %s", peerKey, err) - return nil, mapError(err) + log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err) + return nil, mapError(ctx, err) } // if the login request contains setup key then it is a registration request if loginReq.GetSetupKey() != "" { - s.ephemeralManager.OnPeerDisconnected(peer) + s.ephemeralManager.OnPeerDisconnected(ctx, peer) } trt, err := s.turnRelayTokenManager.Generate() @@ -386,10 +429,11 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p loginResp := &proto.LoginResponse{ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, trt), PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), + Checks: toProtocolChecks(ctx, postureChecks), } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { - log.Warnf("failed encrypting peer %s message", peer.ID) + log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) return nil, status.Errorf(codes.Internal, "failed logging in peer") } @@ -399,6 +443,31 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p }, nil } +// processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if +// the token is valid. +// +// The user ID can be empty if the token is not provided, which is acceptable if the peer is already +// registered or if it uses a setup key to register. +func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { + userID := "" + if loginReq.GetJwtToken() != "" { + var err error + for i := 0; i < 3; i++ { + userID, err = s.validateToken(ctx, loginReq.GetJwtToken()) + if err == nil { + break + } + log.WithContext(ctx).Warnf("failed validating JWT token sent from peer %s with error %v. "+ + "Trying again as it may be due to the IdP cache issue", peerKey.String(), err) + time.Sleep(200 * time.Millisecond) + } + if err != nil { + return "", err + } + } + return userID, nil +} + func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol { switch configProto { case UDP: @@ -412,7 +481,7 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol { case TCP: return proto.HostConfig_TCP default: - panic(fmt.Errorf("unexpected turnCfg protocol type %v", configProto)) + panic(fmt.Errorf("unexpected config protocol type %v", configProto)) } } @@ -495,7 +564,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee return remotePeers } -func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse { wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials) pConfig := toPeerConfig(peer, networkMap.Network, dnsName) @@ -526,6 +595,7 @@ func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNRela FirewallRules: firewallRules, FirewallRulesIsEmpty: len(firewallRules) == 0, }, + Checks: toProtocolChecks(ctx, checks), } } @@ -535,7 +605,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { // make secret time based TURN credentials optional var turnCredentials *TURNRelayToken trt, err := s.turnRelayTokenManager.Generate() @@ -545,8 +615,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net if s.config.TURNConfig.TimeBasedCredentials { turnCredentials = trt } - - plainResp := toSyncResponse(s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain()) + plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { @@ -559,7 +628,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net }) if err != nil { - log.Errorf("failed sending SyncResponse %v", err) + log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) return status.Errorf(codes.Internal, "error handling request") } @@ -573,14 +642,14 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey) - log.Warn(errMSG) + log.WithContext(ctx).Warn(errMSG) return nil, status.Error(codes.InvalidArgument, errMSG) } err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) - log.Warn(errMSG) + log.WithContext(ctx).Warn(errMSG) return nil, status.Error(codes.InvalidArgument, errMSG) } @@ -621,18 +690,18 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. // GetPKCEAuthorizationFlow returns a pkce authorization flow information // This is used for initiating an Oauth 2 pkce authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey) - log.Warn(errMSG) + log.WithContext(ctx).Warn(errMSG) return nil, status.Error(codes.InvalidArgument, errMSG) } err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) - log.Warn(errMSG) + log.WithContext(ctx).Warn(errMSG) return nil, status.Error(codes.InvalidArgument, errMSG) } @@ -663,3 +732,61 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr Body: encryptedResp, }, nil } + +// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected, +// peer's under the same account of any updates. +func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { + realIP := getRealIP(ctx) + log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + + syncMetaReq := &proto.SyncMetaRequest{} + peerKey, err := s.parseRequest(ctx, req, syncMetaReq) + if err != nil { + return nil, err + } + + if syncMetaReq.GetMeta() == nil { + msg := status.Errorf(codes.FailedPrecondition, + "peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) + log.WithContext(ctx).Warn(msg) + return nil, msg + } + + err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta())) + if err != nil { + return nil, mapError(ctx, err) + } + + return &proto.Empty{}, nil +} + +// toProtocolChecks converts posture checks to protocol checks. +func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks { + protoChecks := make([]*proto.Checks, 0, len(postureChecks)) + for _, postureCheck := range postureChecks { + protoChecks = append(protoChecks, toProtocolCheck(postureCheck)) + } + + return protoChecks +} + +// toProtocolCheck converts a posture.Checks to a proto.Checks. +func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks { + protoCheck := &proto.Checks{} + + if check := postureCheck.Checks.ProcessCheck; check != nil { + for _, process := range check.Processes { + if process.LinuxPath != "" { + protoCheck.Files = append(protoCheck.Files, process.LinuxPath) + } + if process.MacPath != "" { + protoCheck.Files = append(protoCheck.Files, process.MacPath) + } + if process.WindowsPath != "" { + protoCheck.Files = append(protoCheck.Files, process.WindowsPath) + } + } + } + + return protoCheck +} diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index d3c9954d3..ffa5b9a28 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -35,34 +35,34 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } if !(user.HasAdminPower() || user.IsServiceUser) { - util.WriteError(status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) return } resp := toAccountResponse(account) - util.WriteJSONObject(w, []*api.Account{resp}) + util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - _, user, err := h.accountManager.GetAccountFromToken(claims) + _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) accountID := vars["accountId"] if len(accountID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w) return } @@ -96,15 +96,15 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } - updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings) + updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toAccountResponse(updatedAccount) - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } // DeleteAccount is a HTTP DELETE handler to delete an account @@ -118,17 +118,17 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) vars := mux.Vars(r) targetAccountID := vars["accountId"] if len(targetAccountID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid account ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w) return } - err := h.accountManager.DeleteAccount(targetAccountID, claims.UserId) + err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } func toAccountResponse(account *server.Account) *api.Account { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 9d174d0be..45c7679e5 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -22,10 +23,10 @@ import ( func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { return &AccountsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return account, admin, nil }, - UpdateAccountSettingsFunc: func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) { + UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index aeaef6f64..30cb19c0c 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -817,6 +817,8 @@ components: $ref: '#/components/schemas/GeoLocationCheck' peer_network_range_check: $ref: '#/components/schemas/PeerNetworkRangeCheck' + process_check: + $ref: '#/components/schemas/ProcessCheck' NBVersionCheck: description: Posture check for the version of NetBird type: object @@ -905,6 +907,32 @@ components: required: - ranges - action + ProcessCheck: + description: Posture Check for binaries exist and are running in the peer’s system + type: object + properties: + processes: + type: array + items: + $ref: '#/components/schemas/Process' + required: + - processes + Process: + description: Describes the operational activity within a peer's system. + type: object + properties: + linux_path: + description: Path to the process executable file in a Linux operating system + type: string + example: "/usr/local/bin/netbird" + mac_path: + description: Path to the process executable file in a Mac operating system + type: string + example: "/Applications/NetBird.app/Contents/MacOS/netbird" + windows_path: + description: Path to the process executable file in a Windows operating system + type: string + example: "C:\ProgramData\NetBird\netbird.exe" Location: description: Describe geographical location information type: object @@ -995,9 +1023,17 @@ components: type: string example: chacbco6lnnbn6cg5s91 network: - description: Network range in CIDR format + description: Network range in CIDR format, Conflicts with domains type: string example: 10.64.0.0/24 + domains: + description: Domain list to be dynamically resolved. Conflicts with network + type: array + items: + type: string + minLength: 1 + maxLength: 255 + example: "example.com" metric: description: Route metric number. Lowest number has higher priority type: integer @@ -1014,6 +1050,10 @@ components: items: type: string example: "chacdk86lnnboviihd70" + keep_route: + description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore + type: boolean + example: true required: - id - description @@ -1022,10 +1062,13 @@ components: # Only one property has to be set #- peer #- peer_groups - - network + # Only one property has to be set + #- network + #- domains - metric - masquerade - groups + - keep_route Route: allOf: - type: object @@ -1035,7 +1078,7 @@ components: type: string example: chacdk86lnnboviihd7g network_type: - description: Network type indicating if it is IPv4 or IPv6 + description: Network type indicating if it is a domain route or a IPv4/IPv6 route type: string example: IPv4 required: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index e378213a1..f731356ee 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -225,6 +225,9 @@ type Checks struct { // PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"` + + // ProcessCheck Posture Check for binaries exist and are running in the peer’s system + ProcessCheck *ProcessCheck `json:"process_check,omitempty"` } // City Describe city geographical location information @@ -949,11 +952,31 @@ type PostureCheckUpdate struct { Name string `json:"name"` } +// Process Describes the operational activity within a peer's system. +type Process struct { + // LinuxPath Path to the process executable file in a Linux operating system + LinuxPath *string `json:"linux_path,omitempty"` + + // MacPath Path to the process executable file in a Mac operating system + MacPath *string `json:"mac_path,omitempty"` + + // WindowsPath Path to the process executable file in a Windows operating system + WindowsPath *string `json:"windows_path,omitempty"` +} + +// ProcessCheck Posture Check for binaries exist and are running in the peer’s system +type ProcessCheck struct { + Processes []Process `json:"processes"` +} + // Route defines model for Route. type Route struct { // Description Route description Description string `json:"description"` + // Domains Domain list to be dynamically resolved. Conflicts with network + Domains *[]string `json:"domains,omitempty"` + // Enabled Route status Enabled bool `json:"enabled"` @@ -963,19 +986,22 @@ type Route struct { // Id Route Id Id string `json:"id"` + // KeepRoute Indicate if the route should be kept after a domain doesn't resolve that IP anymore + KeepRoute bool `json:"keep_route"` + // Masquerade Indicate if peer should masquerade traffic to this route's prefix Masquerade bool `json:"masquerade"` // Metric Route metric number. Lowest number has higher priority Metric int `json:"metric"` - // Network Network range in CIDR format - Network string `json:"network"` + // Network Network range in CIDR format, Conflicts with domains + Network *string `json:"network,omitempty"` // NetworkId Route network identifier, to group HA routes NetworkId string `json:"network_id"` - // NetworkType Network type indicating if it is IPv4 or IPv6 + // NetworkType Network type indicating if it is a domain route or a IPv4/IPv6 route NetworkType string `json:"network_type"` // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` @@ -990,20 +1016,26 @@ type RouteRequest struct { // Description Route description Description string `json:"description"` + // Domains Domain list to be dynamically resolved. Conflicts with network + Domains *[]string `json:"domains,omitempty"` + // Enabled Route status Enabled bool `json:"enabled"` // Groups Group IDs containing routing peers Groups []string `json:"groups"` + // KeepRoute Indicate if the route should be kept after a domain doesn't resolve that IP anymore + KeepRoute bool `json:"keep_route"` + // Masquerade Indicate if peer should masquerade traffic to this route's prefix Masquerade bool `json:"masquerade"` // Metric Route metric number. Lowest number has higher priority Metric int `json:"metric"` - // Network Network range in CIDR format - Network string `json:"network"` + // Network Network range in CIDR format, Conflicts with domains + Network *string `json:"network,omitempty"` // NetworkId Route network identifier, to group HA routes NetworkId string `json:"network_id"` diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go index baaf7ba69..74b0e1a55 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/dns_settings_handler.go @@ -32,16 +32,16 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg // GetDNSSettings returns the DNS settings for the account func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - dnsSettings, err := h.accountManager.GetDNSSettings(account.Id, user.Id) + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -49,15 +49,15 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque DisabledManagementGroups: dnsSettings.DisabledManagementGroups, } - util.WriteJSONObject(w, apiDNSSettings) + util.WriteJSONObject(r.Context(), w, apiDNSSettings) } // UpdateDNSSettings handles update to DNS settings of an account func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -72,9 +72,9 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: req.DisabledManagementGroups, } - err = h.accountManager.SaveDNSSettings(account.Id, user.Id, updateDNSSettings) + err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -82,5 +82,5 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups, } - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go index a2f65a521..897ae63dc 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/dns_settings_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -42,16 +43,16 @@ var testingDNSSettingsAccount = &server.Account{ func initDNSSettingsTestData() *DNSSettingsHandler { return &DNSSettingsHandler{ accountManager: &mock_server.MockAccountManager{ - GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) { + GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { return &testingDNSSettingsAccount.DNSSettings, nil }, - SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { + SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { if dnsSettingsToSave != nil { return nil } return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil }, }, diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index a89c206a3..428b4c164 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -1,6 +1,7 @@ package http import ( + "context" "fmt" "net/http" @@ -33,16 +34,16 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev // GetAllEvents list of the given account func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - accountEvents, err := h.accountManager.GetEvents(account.Id, user.Id) + accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } events := make([]*api.Event, len(accountEvents)) @@ -50,20 +51,20 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { events[i] = toEventResponse(e) } - err = h.fillEventsWithUserInfo(events, account.Id, user.Id) + err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, events) + util.WriteJSONObject(r.Context(), w, events) } -func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error { +func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { // build email, name maps based on users - userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId) + userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) if err != nil { - log.Errorf("failed to get users from account: %s", err) + log.WithContext(ctx).Errorf("failed to get users from account: %s", err) return err } @@ -80,7 +81,7 @@ func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, u if event.InitiatorEmail == "" { event.InitiatorEmail, ok = emails[event.InitiatorId] if !ok { - log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId) + log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId) } } diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index 4cfad922b..8bdd508bf 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -1,6 +1,7 @@ package http import ( + "context" "encoding/json" "io" "net/http" @@ -22,13 +23,13 @@ import ( func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler { return &EventsHandler{ accountManager: &mock_server.MockAccountManager{ - GetEventsFunc: func(accountID, userID string) ([]*activity.Event, error) { + GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { if accountID == account { return events, nil } return []*activity.Event{}, nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", @@ -37,7 +38,7 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E }, }, user, nil }, - GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { return make([]*server.UserInfo, 0), nil }, }, diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go index 226711002..b8247f78d 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/geolocation_handler_test.go @@ -1,6 +1,7 @@ package http import ( + "context" "encoding/json" "io" "net/http" @@ -35,13 +36,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile)) assert.NoError(t, err) - geo, err := geolocation.NewGeolocation(tempDir) + geo, err := geolocation.NewGeolocation(context.Background(), tempDir) assert.NoError(t, err) t.Cleanup(func() { _ = geo.Stop() }) return &GeolocationsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, diff --git a/management/server/http/geolocations_handler.go b/management/server/http/geolocations_handler.go index 070aa6350..af4d3116f 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/geolocations_handler.go @@ -2,6 +2,7 @@ package http import ( "net/http" + "regexp" "github.com/gorilla/mux" @@ -13,6 +14,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +var ( + countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$") +) + // GeolocationsHandler is a handler that returns locations. type GeolocationsHandler struct { accountManager server.AccountManager @@ -35,19 +40,19 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca // GetAllCountries retrieves a list of all countries func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } if l.geolocationManager == nil { // TODO: update error message to include geo db self hosted doc link when ready - util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) + util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) return } allCountries, err := l.geolocationManager.GetAllCountries() if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -55,32 +60,32 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req for _, country := range allCountries { countries = append(countries, toCountryResponse(country)) } - util.WriteJSONObject(w, countries) + util.WriteJSONObject(r.Context(), w, countries) } // GetCitiesByCountry retrieves a list of cities based on the given country code func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) countryCode := vars["country"] if !countryCodeRegex.MatchString(countryCode) { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid country code"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid country code"), w) return } if l.geolocationManager == nil { - // TODO: update error message to include geo db self hosted doc link when ready - util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) + util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+ + "Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w) return } allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -88,12 +93,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. for _, city := range allCities { cities = append(cities, toCityResponse(city)) } - util.WriteJSONObject(w, cities) + util.WriteJSONObject(r.Context(), w, cities) } func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) - _, user, err := l.accountManager.GetAccountFromToken(claims) + _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { return err } diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 47bcf2f32..c622d873a 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -35,16 +35,16 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - groups, err := h.accountManager.GetAllGroups(account.Id, user.Id) + groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -53,42 +53,42 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { groupsResponse = append(groupsResponse, toGroupResponse(account, group)) } - util.WriteJSONObject(w, groupsResponse) + util.WriteJSONObject(r.Context(), w, groupsResponse) } // UpdateGroup handles update to a group identified by a given ID func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) groupID, ok := vars["groupId"] if !ok { - util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID field is missing"), w) return } if len(groupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID can't be empty"), w) return } eg, ok := account.Groups[groupID] if !ok { - util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) return } allGroup, err := account.GetGroupAll() if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } if allGroup.ID == groupID { - util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) return } @@ -100,7 +100,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) return } @@ -118,21 +118,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { IntegrationReference: eg.IntegrationReference, } - if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil { - log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) - util.WriteError(err, w) + if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil { + log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toGroupResponse(account, &group)) + util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) } // CreateGroup handles group creation request func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -144,7 +144,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) return } @@ -160,62 +160,62 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { Issued: nbgroup.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(account.Id, user.Id, &group) + err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toGroupResponse(account, &group)) + util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) } // DeleteGroup handles group deletion request func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } aID := account.Id groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } allGroup, err := account.GetGroupAll() if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } if allGroup.ID == groupID { - util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) return } - err = h.accountManager.DeleteGroup(aID, user.Id, groupID) + err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID) if err != nil { _, ok := err.(*server.GroupLinkError) if ok { util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) return } - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -223,19 +223,19 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { case http.MethodGet: groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } - group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id) + group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toGroupResponse(account, group)) + util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group)) default: - util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w) return } } diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 3d74b848c..d5ed07c9e 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -32,13 +33,13 @@ var TestPeers = map[string]*nbpeer.Peer{ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error { + SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) { + GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { if groupID != "idofthegroup" { return nil, status.Errorf(status.NotFound, "not found") } @@ -55,7 +56,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { Issued: nbgroup.GroupIssuedAPI, }, nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", @@ -70,7 +71,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { }, }, user, nil }, - DeleteGroupFunc: func(accountID, userId, groupID string) error { + DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { if groupID == "linked-grp" { return &server.GroupLinkError{ Resource: "something", diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 4405d295c..3fe26d0ce 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -9,6 +9,7 @@ import ( "github.com/rs/cors" "github.com/netbirdio/management-integrations/integrations" + s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/middleware" @@ -57,6 +58,11 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa corsMiddleware := cors.AllowAll() + claimsExtractor = jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ) + acMiddleware := middleware.NewAccessControl( authCfg.Audience, authCfg.UserIDClaim, diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index de386f173..0ad250f43 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "net/http" "regexp" @@ -15,7 +16,7 @@ import ( ) // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims -type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error) +type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { @@ -46,15 +47,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { claims := a.claimsExtract.FromRequestContext(r) - user, err := a.getUser(claims) + user, err := a.getUser(r.Context(), claims) if err != nil { - log.Errorf("failed to get user from claims: %s", err) - util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) + log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w) return } if user.IsBlocked() { - util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w) + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w) return } @@ -63,12 +64,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: if tokenPathRegexp.MatchString(r.URL.Path) { - log.Debugf("valid Path") + log.WithContext(r.Context()).Debugf("valid Path") h.ServeHTTP(w, r) return } - util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w) + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w) return } } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 204c9f4eb..b25aad99c 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -19,16 +20,16 @@ import ( ) // GetAccountFromPATFunc function -type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) +type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) // ValidateAndParseTokenFunc function -type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) +type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) // MarkPATUsedFunc function -type MarkPATUsedFunc func(token string) error +type MarkPATUsedFunc func(ctx context.Context, token string) error // CheckUserAccessByJWTGroupsFunc function -type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error +type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { @@ -85,23 +86,27 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { case "bearer": err := m.checkJWTFromRequest(w, r, auth) if err != nil { - log.Errorf("Error when validating JWT claims: %s", err.Error()) - util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) + log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error()) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } - h.ServeHTTP(w, r) case "token": err := m.checkPATFromRequest(w, r, auth) if err != nil { - log.Debugf("Error when validating PAT claims: %s", err.Error()) - util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) + log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error()) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } - h.ServeHTTP(w, r) default: - util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return } + claims := m.claimsExtractor.FromRequestContext(r) + //nolint + ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId) + //nolint + ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId) + h.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -114,7 +119,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ return fmt.Errorf("Error extracting token: %w", err) } - validatedToken, err := m.validateAndParseToken(token) + validatedToken, err := m.validateAndParseToken(r.Context(), token) if err != nil { return err } @@ -123,7 +128,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ return nil } - if err := m.verifyUserAccess(validatedToken); err != nil { + if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil { return err } @@ -138,9 +143,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ // verifyUserAccess checks if a user, based on a validated JWT token, // is allowed access, particularly in cases where the admin enabled JWT // group propagation and designated certain groups with access permissions. -func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error { +func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error { authClaims := m.claimsExtractor.FromToken(validatedToken) - return m.checkUserAccessByJWTGroups(authClaims) + return m.checkUserAccessByJWTGroups(ctx, authClaims) } // CheckPATFromRequest checks if the PAT is valid @@ -152,7 +157,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ return fmt.Errorf("Error extracting token: %w", err) } - account, user, pat, err := m.getAccountFromPAT(token) + account, user, pat, err := m.getAccountFromPAT(r.Context(), token) if err != nil { return fmt.Errorf("invalid Token: %w", err) } @@ -160,7 +165,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ return fmt.Errorf("token expired") } - err = m.markPATUsed(pat.ID) + err = m.markPATUsed(r.Context(), pat.ID) if err != nil { return err } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 588bcaf02..fdfb0ea24 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -15,15 +16,16 @@ import ( ) const ( - audience = "audience" - userIDClaim = "userIDClaim" - accountID = "accountID" - domain = "domain" - userID = "userID" - tokenID = "tokenID" - PAT = "nbp_PAT" - JWT = "JWT" - wrongToken = "wrongToken" + audience = "audience" + userIDClaim = "userIDClaim" + accountID = "accountID" + domain = "domain" + domainCategory = "domainCategory" + userID = "userID" + tokenID = "tokenID" + PAT = "nbp_PAT" + JWT = "JWT" + wrongToken = "wrongToken" ) var testAccount = &server.Account{ @@ -47,14 +49,14 @@ var testAccount = &server.Account{ }, } -func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { if token == PAT { return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil } return nil, nil, nil, fmt.Errorf("PAT invalid") } -func mockValidateAndParseToken(token string) (*jwt.Token, error) { +func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { if token == JWT { return &jwt.Token{ Claims: jwt.MapClaims{ @@ -67,14 +69,14 @@ func mockValidateAndParseToken(token string) (*jwt.Token, error) { return nil, fmt.Errorf("JWT invalid") } -func mockMarkPATUsed(token string) error { +func mockMarkPATUsed(_ context.Context, token string) error { if token == tokenID { return nil } return fmt.Errorf("Should never get reached") } -func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { +func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error { if testAccount.Id != claims.AccountId { return fmt.Errorf("account with id %s does not exist", claims.AccountId) } diff --git a/management/server/http/middleware/bypass/bypass.go b/management/server/http/middleware/bypass/bypass.go index 87b41c6fc..9447704cb 100644 --- a/management/server/http/middleware/bypass/bypass.go +++ b/management/server/http/middleware/bypass/bypass.go @@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r * for bypassPath := range bypassPaths { matched, err := path.Match(bypassPath, requestPath) if err != nil { - log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err) + log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err) continue } if matched { diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index 8d9f0d717..c6e00bb2d 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -36,16 +36,16 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg // GetAllNameservers returns the list of nameserver groups for the account func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - nsGroups, err := h.accountManager.ListNameServerGroups(account.Id, user.Id) + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -54,15 +54,15 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re apiNameservers = append(apiNameservers, toNameserverGroupResponse(r)) } - util.WriteJSONObject(w, apiNameservers) + util.WriteJSONObject(r.Context(), w, apiNameservers) } // CreateNameserverGroup handles nameserver group creation request func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -75,33 +75,33 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt nsList, err := toServerNSList(req.Nameservers) if err != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) return } - nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toNameserverGroupResponse(nsGroup) - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } // UpdateNameserverGroup handles update to a nameserver group identified by a given ID func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } @@ -114,7 +114,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt nsList, err := toServerNSList(req.Nameservers) if err != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) return } @@ -130,66 +130,66 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt SearchDomainsEnabled: req.SearchDomainsEnabled, } - err = h.accountManager.SaveNameServerGroup(account.Id, user.Id, updatedNSGroup) + err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toNameserverGroupResponse(updatedNSGroup) - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } // DeleteNameserverGroup handles nameserver group deletion request func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } - err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID, user.Id) + err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // GetNameserverGroup handles a nameserver group Get request identified by ID func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } - nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, user.Id, nsGroupID) + nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toNameserverGroupResponse(nsGroup) - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index e1fabb198..28b080571 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -61,13 +62,13 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ func initNameserversTestData() *NameserversHandler { return &NameserversHandler{ accountManager: &mock_server.MockAccountManager{ - GetNameServerGroupFunc: func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { + GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if nsGroupID == existingNSGroupID { return baseExistingNSGroup.Copy(), nil } return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) }, - CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) { + CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) { return &nbdns.NameServerGroup{ ID: existingNSGroupID, Name: name, @@ -80,16 +81,16 @@ func initNameserversTestData() *NameserversHandler { SearchDomainsEnabled: searchDomains, }, nil }, - DeleteNameServerGroupFunc: func(accountID, nsGroupID, _ string) error { + DeleteNameServerGroupFunc: func(_ context.Context, accountID, nsGroupID, _ string) error { return nil }, - SaveNameServerGroupFunc: func(accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error { + SaveNameServerGroupFunc: func(_ context.Context, accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error { if nsGroupToSave.ID == existingNSGroupID { return nil } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingNSAccount, testingAccount.Users["test_user"], nil }, }, diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index d2398a7e1..9d8448d3d 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -34,22 +34,22 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) userID := vars["userId"] if len(userID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID) + pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -58,53 +58,53 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { patResponse = append(patResponse, toPATResponse(pat)) } - util.WriteJSONObject(w, patResponse) + util.WriteJSONObject(r.Context(), w, patResponse) } // GetToken is HTTP GET handler that returns a personal access token for the given user func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } tokenID := vars["tokenId"] if len(tokenID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w) return } - pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID) + pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toPATResponse(pat)) + util.WriteJSONObject(r.Context(), w, toPATResponse(pat)) } // CreateToken is HTTP POST handler that creates a personal access token for the given user func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } @@ -115,44 +115,44 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) + pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toPATGeneratedResponse(pat)) + util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat)) } // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } tokenID := vars["tokenId"] if len(tokenID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w) return } - err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID) + err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index 45fda0a55..b72f71468 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -27,12 +28,12 @@ const ( notFoundUserID = "notFoundUserID" existingTokenID = "existingTokenID" notFoundTokenID = "notFoundTokenID" - domain = "hotmail.com" + testDomain = "hotmail.com" ) var testAccount = &server.Account{ Id: existingAccountID, - Domain: domain, + Domain: testDomain, Users: map[string]*server.User{ existingUserID: { Id: existingUserID, @@ -63,7 +64,7 @@ var testAccount = &server.Account{ func initPATTestData() *PATHandler { return &PATHandler{ accountManager: &mock_server.MockAccountManager{ - CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -76,10 +77,10 @@ func initPATTestData() *PATHandler { }, nil }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testAccount, testAccount.Users[existingUserID], nil }, - DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { + DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { return status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -91,7 +92,7 @@ func initPATTestData() *PATHandler { } return nil }, - GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -103,7 +104,7 @@ func initPATTestData() *PATHandler { } return testAccount.Users[existingUserID].PATs[existingTokenID], nil }, - GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -117,7 +118,7 @@ func initPATTestData() *PATHandler { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: existingUserID, - Domain: domain, + Domain: testDomain, AccountId: testNSGroupAccountID, } }), diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 762576506..1fb18669c 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -1,6 +1,7 @@ package http import ( + "context" "encoding/json" "fmt" "net/http" @@ -47,16 +48,16 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) { - peer, err := h.accountManager.GetPeer(account.Id, peerID, userID) +func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { + peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) if err != nil { - util.WriteError(err, w) + util.WriteError(ctx, err, w) return } peerToReturn, err := h.checkPeerStatus(peer) if err != nil { - util.WriteError(err, w) + util.WriteError(ctx, err, w) return } dnsDomain := h.accountManager.GetDNSDomain() @@ -65,19 +66,19 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w validPeers, err := h.accountManager.GetValidatedPeers(account) if err != nil { - log.Errorf("failed to list appreoved peers: %v", err) - util.WriteError(fmt.Errorf("internal error"), w) + log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + util.WriteError(ctx, fmt.Errorf("internal error"), w) return } - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) + netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) _, valid := validPeers[peer.ID] - util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) } -func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -99,9 +100,9 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe } } - peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update) + peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update) if err != nil { - util.WriteError(err, w) + util.WriteError(ctx, err, w) return } dnsDomain := h.accountManager.GetDNSDomain() @@ -110,75 +111,75 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe validPeers, err := h.accountManager.GetValidatedPeers(account) if err != nil { - log.Errorf("failed to list appreoved peers: %v", err) - util.WriteError(fmt.Errorf("internal error"), w) + log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + util.WriteError(ctx, fmt.Errorf("internal error"), w) return } - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) + netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) _, valid := validPeers[peer.ID] - util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) } -func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { - err := h.accountManager.DeletePeer(accountID, peerID, userID) +func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { + err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) if err != nil { - log.Errorf("failed to delete peer: %v", err) - util.WriteError(err, w) + log.WithContext(ctx).Errorf("failed to delete peer: %v", err) + util.WriteError(ctx, err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(ctx, w, emptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) return } switch r.Method { case http.MethodDelete: - h.deletePeer(account.Id, user.Id, peerID, w) + h.deletePeer(r.Context(), account.Id, user.Id, peerID, w) return case http.MethodPut: - h.updatePeer(account, user, peerID, w, r) + h.updatePeer(r.Context(), account, user, peerID, w, r) return case http.MethodGet: - h.getPeer(account, peerID, user.Id, w) + h.getPeer(r.Context(), account, peerID, user.Id, w) return default: - util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) } } // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { - util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) return } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - peers, err := h.accountManager.GetPeers(account.Id, user.Id) + peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -188,34 +189,34 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID) + accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) } validPeersMap, err := h.accountManager.GetValidatedPeers(account) if err != nil { - log.Errorf("failed to list appreoved peers: %v", err) - util.WriteError(fmt.Errorf("internal error"), w) + log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } h.setApprovalRequiredFlag(respBody, validPeersMap) - util.WriteJSONObject(w, respBody) + util.WriteJSONObject(r.Context(), w, respBody) } -func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) { +func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) { validatedPeersMap, err := h.accountManager.GetValidatedPeers(account) if err != nil { return 0, err } - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) + netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) return len(netMap.Peers) + len(netMap.OfflinePeers), nil } diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index 53df5cb00..153c8f03a 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net" @@ -29,7 +30,7 @@ const noUpdateChannelTestPeerID = "no-update-channel" func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { return &PeersHandler{ accountManager: &mock_server.MockAccountManager{ - UpdatePeerFunc: func(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { + UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { var p *nbpeer.Peer for _, peer := range peers { if update.ID == peer.ID { @@ -42,7 +43,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { p.Name = update.Name return p, nil }, - GetPeerFunc: func(accountID, peerID, userID string) (*nbpeer.Peer, error) { + GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { var p *nbpeer.Peer for _, peer := range peers { if peerID == peer.ID { @@ -52,13 +53,13 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { } return p, nil }, - GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) { + GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index e163e63b9..9622668f4 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -35,15 +35,15 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllPolicies list for the account func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id) + accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -51,28 +51,28 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { for _, policy := range accountPolicies { resp := toPolicyResponse(account, policy) if len(resp.Rules) == 0 { - util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return } policies = append(policies, resp) } - util.WriteJSONObject(w, policies) + util.WriteJSONObject(r.Context(), w, policies) } // UpdatePolicy handles update to a policy identified by a given ID func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) return } @@ -84,7 +84,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { } } if policyIdx < 0 { - util.WriteError(status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) return } @@ -94,9 +94,9 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { // CreatePolicy handles policy creation request func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -118,12 +118,12 @@ func (h *Policies) savePolicy( } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w) return } if len(req.Rules) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w) return } @@ -137,31 +137,31 @@ func (h *Policies) savePolicy( Enabled: req.Enabled, Description: req.Description, } - for _, r := range req.Rules { + for _, rule := range req.Rules { pr := server.PolicyRule{ - ID: policyID, //TODO: when policy can contain multiple rules, need refactor - Name: r.Name, - Destinations: groupMinimumsToStrings(account, r.Destinations), - Sources: groupMinimumsToStrings(account, r.Sources), - Bidirectional: r.Bidirectional, + ID: policyID, // TODO: when policy can contain multiple rules, need refactor + Name: rule.Name, + Destinations: groupMinimumsToStrings(account, rule.Destinations), + Sources: groupMinimumsToStrings(account, rule.Sources), + Bidirectional: rule.Bidirectional, } - pr.Enabled = r.Enabled - if r.Description != nil { - pr.Description = *r.Description + pr.Enabled = rule.Enabled + if rule.Description != nil { + pr.Description = *rule.Description } - switch r.Action { + switch rule.Action { case api.PolicyRuleUpdateActionAccept: pr.Action = server.PolicyTrafficActionAccept case api.PolicyRuleUpdateActionDrop: pr.Action = server.PolicyTrafficActionDrop default: - util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w) return } - switch r.Protocol { + switch rule.Protocol { case api.PolicyRuleUpdateProtocolAll: pr.Protocol = server.PolicyRuleProtocolALL case api.PolicyRuleUpdateProtocolTcp: @@ -171,14 +171,14 @@ func (h *Policies) savePolicy( case api.PolicyRuleUpdateProtocolIcmp: pr.Protocol = server.PolicyRuleProtocolICMP default: - util.WriteError(status.Errorf(status.InvalidArgument, "unknown protocol type: %v", r.Protocol), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) return } - if r.Ports != nil && len(*r.Ports) != 0 { - for _, v := range *r.Ports { + if rule.Ports != nil && len(*rule.Ports) != 0 { + for _, v := range *rule.Ports { if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 { - util.WriteError(status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) return } pr.Ports = append(pr.Ports, v) @@ -189,16 +189,16 @@ func (h *Policies) savePolicy( switch pr.Protocol { case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: if len(pr.Ports) != 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return } if !pr.Bidirectional { - util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: if !pr.Bidirectional && len(pr.Ports) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } } @@ -210,26 +210,26 @@ func (h *Policies) savePolicy( policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks) } - if err := h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil { - util.WriteError(err, w) + if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil { + util.WriteError(r.Context(), err, w) return } resp := toPolicyResponse(account, &policy) if len(resp.Rules) == 0 { - util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return } - util.WriteJSONObject(w, resp) + util.WriteJSONObject(r.Context(), w, resp) } // DeletePolicy handles policy deletion request func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } aID := account.Id @@ -237,24 +237,24 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) return } - if err = h.accountManager.DeletePolicy(aID, policyID, user.Id); err != nil { - util.WriteError(err, w) + if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil { + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // GetPolicy handles a group Get request identified by ID func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -263,25 +263,25 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) return } - policy, err := h.accountManager.GetPolicy(account.Id, policyID, user.Id) + policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toPolicyResponse(account, policy) if len(resp.Rules) == 0 { - util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return } - util.WriteJSONObject(w, resp) + util.WriteJSONObject(r.Context(), w, resp) default: - util.WriteError(status.Errorf(status.NotFound, "method not found"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w) } } diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 74e682854..06274fb07 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -30,21 +31,21 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return &Policies{ accountManager: &mock_server.MockAccountManager{ - GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) { + GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { policy, ok := testPolicies[policyID] if !ok { return nil, status.Errorf(status.NotFound, "policy not found") } return policy, nil }, - SavePolicyFunc: func(_, _ string, policy *server.Policy) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index f256d9ee0..059cb3b80 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -3,8 +3,6 @@ package http import ( "encoding/json" "net/http" - "regexp" - "slices" "github.com/gorilla/mux" @@ -17,10 +15,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -var ( - countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$") -) - // PostureChecksHandler is a handler that returns posture checks of the account. type PostureChecksHandler struct { accountManager server.AccountManager @@ -43,15 +37,15 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa // GetAllPostureChecks list for the account func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - accountPostureChecks, err := p.accountManager.ListPostureChecks(account.Id, user.Id) + accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -60,22 +54,22 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) } - util.WriteJSONObject(w, postureChecks) + util.WriteJSONObject(r.Context(), w, postureChecks) } // UpdatePostureCheck handles update to a posture check identified by a given ID func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) return } @@ -87,7 +81,7 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http } } if postureChecksIdx < 0 { - util.WriteError(status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) return } @@ -97,9 +91,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http // CreatePostureCheck handles posture check creation request func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -109,50 +103,50 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http // GetPostureCheck handles a posture check Get request identified by ID func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) return } - postureChecks, err := p.accountManager.GetPostureChecks(account.Id, postureChecksID, user.Id) + postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, postureChecks.ToAPIResponse()) + util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) } // DeletePostureCheck handles posture check deletion request func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) return } - if err = p.accountManager.DeletePostureChecks(account.Id, postureChecksID, user.Id); err != nil { - util.WriteError(err, w) + if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil { + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // savePostureChecks handles posture checks create and update @@ -163,103 +157,34 @@ func (p *PostureChecksHandler) savePostureChecks( user *server.User, postureChecksID string, ) { + var ( + err error + req api.PostureCheckUpdate + ) - var req api.PostureCheckUpdate - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + if err = json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } - err := validatePostureChecksUpdate(req) - if err != nil { - util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) - return - } - if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil { if p.geolocationManager == nil { - // TODO: update error message to include geo db self hosted doc link when ready - util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) + util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+ + "Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w) return } } postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil { - util.WriteError(err, w) + if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil { + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, postureChecks.ToAPIResponse()) -} - -func validatePostureChecksUpdate(req api.PostureCheckUpdate) error { - if req.Name == "" { - return status.Errorf(status.InvalidArgument, "posture checks name shouldn't be empty") - } - - if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil && - req.Checks.GeoLocationCheck == nil && req.Checks.PeerNetworkRangeCheck == nil) { - return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty") - } - - if req.Checks.NbVersionCheck != nil && req.Checks.NbVersionCheck.MinVersion == "" { - return status.Errorf(status.InvalidArgument, "minimum version for NetBird's version check shouldn't be empty") - } - - if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil { - emptyOS := osVersionCheck.Android == nil && osVersionCheck.Darwin == nil && osVersionCheck.Ios == nil && - osVersionCheck.Linux == nil && osVersionCheck.Windows == nil - emptyMinVersion := osVersionCheck.Android != nil && osVersionCheck.Android.MinVersion == "" || - osVersionCheck.Darwin != nil && osVersionCheck.Darwin.MinVersion == "" || - osVersionCheck.Ios != nil && osVersionCheck.Ios.MinVersion == "" || - osVersionCheck.Linux != nil && osVersionCheck.Linux.MinKernelVersion == "" || - osVersionCheck.Windows != nil && osVersionCheck.Windows.MinKernelVersion == "" - if emptyOS || emptyMinVersion { - return status.Errorf(status.InvalidArgument, - "minimum version for at least one OS in the OS version check shouldn't be empty") - } - } - - if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil { - if geoLocationCheck.Action == "" { - return status.Errorf(status.InvalidArgument, "action for geolocation check shouldn't be empty") - } - allowedActions := []api.GeoLocationCheckAction{api.GeoLocationCheckActionAllow, api.GeoLocationCheckActionDeny} - if !slices.Contains(allowedActions, geoLocationCheck.Action) { - return status.Errorf(status.InvalidArgument, "action for geolocation check is not valid value") - } - if len(geoLocationCheck.Locations) == 0 { - return status.Errorf(status.InvalidArgument, "locations for geolocation check shouldn't be empty") - } - for _, loc := range geoLocationCheck.Locations { - if loc.CountryCode == "" { - return status.Errorf(status.InvalidArgument, "country code for geolocation check shouldn't be empty") - } - if !countryCodeRegex.MatchString(loc.CountryCode) { - return status.Errorf(status.InvalidArgument, "country code must be 2 letters (ISO 3166-1 alpha-2 format)") - } - } - } - - if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil { - if peerNetworkRangeCheck.Action == "" { - return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty") - } - - allowedActions := []api.PeerNetworkRangeCheckAction{api.PeerNetworkRangeCheckActionAllow, api.PeerNetworkRangeCheckActionDeny} - if !slices.Contains(allowedActions, peerNetworkRangeCheck.Action) { - return status.Errorf(status.InvalidArgument, "action for peer network range check is not valid value") - } - if len(peerNetworkRangeCheck.Ranges) == 0 { - return status.Errorf(status.InvalidArgument, "network ranges for peer network range check shouldn't be empty") - } - } - - return nil + util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 70e803214..dcb6e4924 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -33,19 +34,24 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH return &PostureChecksHandler{ accountManager: &mock_server.MockAccountManager{ - GetPostureChecksFunc: func(accountID, postureChecksID, userID string) (*posture.Checks, error) { + GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { p, ok := testPostureChecks[postureChecksID] if !ok { return nil, status.Errorf(status.NotFound, "posture checks not found") } return p, nil }, - SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error { + SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { postureChecks.ID = "postureCheck" testPostureChecks[postureChecks.ID] = postureChecks + + if err := postureChecks.Validate(); err != nil { + return status.Errorf(status.InvalidArgument, err.Error()) + } + return nil }, - DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error { + DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { _, ok := testPostureChecks[postureChecksID] if !ok { return status.Errorf(status.NotFound, "posture checks not found") @@ -54,14 +60,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH return nil }, - ListPostureChecksFunc: func(accountID, userID string) ([]*posture.Checks, error) { + ListPostureChecksFunc: func(_ context.Context, accountID, userID string) ([]*posture.Checks, error) { accountPostureChecks := make([]*posture.Checks, len(testPostureChecks)) for _, p := range testPostureChecks { accountPostureChecks = append(accountPostureChecks, p) } return accountPostureChecks, nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, @@ -433,6 +439,45 @@ func TestPostureCheckUpdate(t *testing.T) { handler.geolocationManager = nil }, }, + { + name: "Create Posture Checks Process Check", + requestType: http.MethodPost, + requestPath: "/api/posture-checks", + requestBody: bytes.NewBuffer( + []byte(`{ + "name": "default", + "description": "default", + "checks": { + "process_check": { + "processes": [ + { + "linux_path": "/usr/local/bin/netbird", + "mac_path": "/Applications/NetBird.app/Contents/MacOS/netbird", + "windows_path": "C:\\ProgramData\\NetBird\\netbird.exe" + } + ] + } + } + }`)), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedPostureCheck: &api.PostureCheck{ + Id: "postureCheck", + Name: "default", + Description: str("default"), + Checks: api.Checks{ + ProcessCheck: &api.ProcessCheck{ + Processes: []api.Process{ + { + LinuxPath: str("/usr/local/bin/netbird"), + MacPath: str("/Applications/NetBird.app/Contents/MacOS/netbird"), + WindowsPath: str("C:\\ProgramData\\NetBird\\netbird.exe"), + }, + }, + }, + }, + }, + }, { name: "Create Posture Checks Invalid Check", requestType: http.MethodPost, @@ -446,7 +491,7 @@ func TestPostureCheckUpdate(t *testing.T) { } } }`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -461,7 +506,7 @@ func TestPostureCheckUpdate(t *testing.T) { } } }`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -475,7 +520,7 @@ func TestPostureCheckUpdate(t *testing.T) { "nb_version_check": {} } }`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -489,7 +534,7 @@ func TestPostureCheckUpdate(t *testing.T) { "geo_location_check": {} } }`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -663,11 +708,8 @@ func TestPostureCheckUpdate(t *testing.T) { } } }`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { - handler.geolocationManager = nil - }, }, { name: "Update Posture Checks Invalid Check", @@ -682,7 +724,7 @@ func TestPostureCheckUpdate(t *testing.T) { } } }`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -697,7 +739,7 @@ func TestPostureCheckUpdate(t *testing.T) { } } }`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -711,7 +753,7 @@ func TestPostureCheckUpdate(t *testing.T) { "nb_version_check": {} } }`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -841,100 +883,3 @@ func TestPostureCheckUpdate(t *testing.T) { }) } } - -func TestPostureCheck_validatePostureChecksUpdate(t *testing.T) { - // empty name - err := validatePostureChecksUpdate(api.PostureCheckUpdate{}) - assert.Error(t, err) - - // empty checks - err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default"}) - assert.Error(t, err) - err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{}}) - assert.Error(t, err) - - // not valid NbVersionCheck - nbVersionCheck := api.NBVersionCheck{} - err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}}) - assert.Error(t, err) - - // valid NbVersionCheck - nbVersionCheck = api.NBVersionCheck{MinVersion: "1.0"} - err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}}) - assert.NoError(t, err) - - // not valid OsVersionCheck - osVersionCheck := api.OSVersionCheck{} - err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}}) - assert.Error(t, err) - - // not valid OsVersionCheck - osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}} - err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}}) - assert.Error(t, err) - - // not valid OsVersionCheck - osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}, Darwin: &api.MinVersionCheck{MinVersion: "14.2"}} - err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}}) - assert.Error(t, err) - - // valid OsVersionCheck - osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"}} - err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}}) - assert.NoError(t, err) - - // valid OsVersionCheck - osVersionCheck = api.OSVersionCheck{ - Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"}, - Darwin: &api.MinVersionCheck{MinVersion: "14.2"}, - } - err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}}) - assert.NoError(t, err) - - // valid peer network range check - peerNetworkRangeCheck := api.PeerNetworkRangeCheck{ - Action: api.PeerNetworkRangeCheckActionAllow, - Ranges: []string{ - "192.168.1.0/24", "10.0.0.0/8", - }, - } - err = validatePostureChecksUpdate( - api.PostureCheckUpdate{ - Name: "Default", - Checks: &api.Checks{ - PeerNetworkRangeCheck: &peerNetworkRangeCheck, - }, - }, - ) - assert.NoError(t, err) - - // invalid peer network range check - peerNetworkRangeCheck = api.PeerNetworkRangeCheck{ - Action: api.PeerNetworkRangeCheckActionDeny, - Ranges: []string{}, - } - err = validatePostureChecksUpdate( - api.PostureCheckUpdate{ - Name: "Default", - Checks: &api.Checks{ - PeerNetworkRangeCheck: &peerNetworkRangeCheck, - }, - }, - ) - assert.Error(t, err) - - // invalid peer network range check - peerNetworkRangeCheck = api.PeerNetworkRangeCheck{ - Action: "unknownAction", - Ranges: []string{}, - } - err = validatePostureChecksUpdate( - api.PostureCheckUpdate{ - Name: "Default", - Checks: &api.Checks{ - PeerNetworkRangeCheck: &peerNetworkRangeCheck, - }, - }, - ) - assert.Error(t, err) -} diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index f755e7a16..18c347334 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -2,11 +2,16 @@ package http import ( "encoding/json" + "fmt" "net/http" + "net/netip" + "regexp" + "strings" "unicode/utf8" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" @@ -15,6 +20,9 @@ import ( "github.com/netbirdio/netbird/route" ) +const maxDomains = 32 +const failedToConvertRoute = "failed to convert route to response: %v" + // RoutesHandler is the routes handler of the account type RoutesHandler struct { accountManager server.AccountManager @@ -35,31 +43,36 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro // GetAllRoutes returns the list of routes for the account func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - routes, err := h.accountManager.ListRoutes(account.Id, user.Id) + routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } apiRoutes := make([]*api.Route, 0) - for _, r := range routes { - apiRoutes = append(apiRoutes, toRouteResponse(r)) + for _, route := range routes { + route, err := toRouteResponse(route) + if err != nil { + util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w) + return + } + apiRoutes = append(apiRoutes, route) } - util.WriteJSONObject(w, apiRoutes) + util.WriteJSONObject(r.Context(), w, apiRoutes) } // CreateRoute handles route creation request func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -70,16 +83,28 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { return } - _, newPrefix, err := route.ParseNetwork(req.Network) - if err != nil { - util.WriteError(err, w) + if err := h.validateRoute(req); err != nil { + util.WriteError(r.Context(), err, w) return } - if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", - route.MaxNetIDChar), w) - return + var domains domain.List + var networkType route.NetworkType + var newPrefix netip.Prefix + if req.Domains != nil { + d, err := validateDomains(*req.Domains) + if err != nil { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) + return + } + domains = d + networkType = route.DomainNetwork + } else if req.Network != nil { + networkType, newPrefix, err = route.ParseNetwork(*req.Network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } } peerId := "" @@ -87,57 +112,78 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { peerId = *req.Peer } - peerGroupIds := []string{} + var peerGroupIds []string if req.PeerGroups != nil { peerGroupIds = *req.PeerGroups } - if (peerId != "" && len(peerGroupIds) > 0) || (peerId == "" && len(peerGroupIds) == 0) { - util.WriteError(status.Errorf(status.InvalidArgument, "only one peer or peer_groups should be provided"), w) - return - } - - // do not allow non Linux peers + // Do not allow non-Linux peers if peer := account.GetPeer(peerId); peer != nil { if peer.Meta.GoOS != "linux" { - util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w) return } } - newRoute, err := h.accountManager.CreateRoute( - account.Id, newPrefix.String(), peerId, peerGroupIds, - req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, - ) + newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - resp := toRouteResponse(newRoute) + routes, err := toRouteResponse(newRoute) + if err != nil { + util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w) + return + } - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, routes) +} + +func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { + if req.Network != nil && req.Domains != nil { + return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided") + } + + if req.Network == nil && req.Domains == nil { + return status.Errorf(status.InvalidArgument, "either 'network' or 'domains' should be provided") + } + + if req.Peer == nil && req.PeerGroups == nil { + return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided") + } + + if req.Peer != nil && req.PeerGroups != nil { + return status.Errorf(status.InvalidArgument, "only one of 'peer' or 'peer_groups' should be provided") + } + + if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { + return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d characters", + route.MaxNetIDChar) + } + + return nil } // UpdateRoute handles update to a route identified by a given ID func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) routeID := vars["routeId"] if len(routeID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - _, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id) + _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -148,26 +194,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { return } - prefixType, newPrefix, err := route.ParseNetwork(req.Network) - if err != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "couldn't parse update prefix %s for route ID %s", - req.Network, routeID), w) - return - } - - if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { - util.WriteError(status.Errorf(status.InvalidArgument, - "identifier should be between 1 and %d", route.MaxNetIDChar), w) - return - } - - if req.Peer != nil && req.PeerGroups != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "only peer or peers_group should be provided"), w) - return - } - - if req.Peer == nil && req.PeerGroups == nil { - util.WriteError(status.Errorf(status.InvalidArgument, "either peer or peers_group should be provided"), w) + if err := h.validateRoute(req); err != nil { + util.WriteError(r.Context(), err, w) return } @@ -179,21 +207,36 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { // do not allow non Linux peers if peer := account.GetPeer(peerID); peer != nil { if peer.Meta.GoOS != "linux" { - util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) return } } newRoute := &route.Route{ ID: route.ID(routeID), - Network: newPrefix, NetID: route.NetID(req.NetworkId), - NetworkType: prefixType, Masquerade: req.Masquerade, Metric: req.Metric, Description: req.Description, Enabled: req.Enabled, Groups: req.Groups, + KeepRoute: req.KeepRoute, + } + + if req.Domains != nil { + d, err := validateDomains(*req.Domains) + if err != nil { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) + return + } + newRoute.Domains = d + newRoute.NetworkType = route.DomainNetwork + } else if req.Network != nil { + newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } } if req.Peer != nil { @@ -204,81 +247,129 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } - err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute) + err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - resp := toRouteResponse(newRoute) + routes, err := toRouteResponse(newRoute) + if err != nil { + util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w) + return + } - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, routes) } // DeleteRoute handles route deletion request func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id) + err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // GetRoute handles a route Get request identified by ID func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id) + foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) if err != nil { - util.WriteError(status.Errorf(status.NotFound, "route not found"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) return } - util.WriteJSONObject(w, toRouteResponse(foundRoute)) + routes, err := toRouteResponse(foundRoute) + if err != nil { + util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w) + return + } + + util.WriteJSONObject(r.Context(), w, routes) } -func toRouteResponse(serverRoute *route.Route) *api.Route { +func toRouteResponse(serverRoute *route.Route) (*api.Route, error) { + domains, err := serverRoute.Domains.ToStringList() + if err != nil { + return nil, err + } + network := serverRoute.Network.String() route := &api.Route{ Id: string(serverRoute.ID), Description: serverRoute.Description, NetworkId: string(serverRoute.NetID), Enabled: serverRoute.Enabled, Peer: &serverRoute.Peer, - Network: serverRoute.Network.String(), + Network: &network, + Domains: &domains, NetworkType: serverRoute.NetworkType.String(), Masquerade: serverRoute.Masquerade, Metric: serverRoute.Metric, Groups: serverRoute.Groups, + KeepRoute: serverRoute.KeepRoute, } if len(serverRoute.PeerGroups) > 0 { route.PeerGroups = &serverRoute.PeerGroups } - return route + return route, nil +} + +// validateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. +func validateDomains(domains []string) (domain.List, error) { + if len(domains) == 0 { + return nil, fmt.Errorf("domains list is empty") + } + if len(domains) > maxDomains { + return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) + } + + domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) + + var domainList domain.List + + for _, d := range domains { + d := strings.ToLower(d) + + // handles length and idna conversion + punycode, err := domain.FromString(d) + if err != nil { + return domainList, fmt.Errorf("failed to convert domain to punycode: %s: %v", d, err) + } + + if !domainRegex.MatchString(string(punycode)) { + return domainList, fmt.Errorf("invalid domain format: %s", d) + } + + domainList = append(domainList, punycode) + } + return domainList, nil } diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 1c8288d5f..40075eb9d 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -10,6 +11,8 @@ import ( "net/netip" "testing" + "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -18,6 +21,7 @@ import ( "github.com/gorilla/mux" "github.com/magiconair/properties/assert" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" @@ -26,6 +30,7 @@ import ( const ( existingRouteID = "existingRouteID" existingRouteID2 = "existingRouteID2" // for peer_groups test + existingRouteID3 = "existingRouteID3" // for domains test notFoundRouteID = "notFoundRouteID" existingPeerIP1 = "100.64.0.100" existingPeerIP2 = "100.64.0.101" @@ -35,6 +40,7 @@ const ( testAccountID = "test_id" existingGroupID = "testGroup" notFoundGroupID = "nonExistingGroup" + existingDomain = "example.com" ) var emptyString = "" @@ -46,6 +52,8 @@ var baseExistingRoute = &route.Route{ Description: "base route", NetID: "awesomeNet", Network: netip.MustParsePrefix("192.168.0.0/24"), + Domains: domain.List{}, + KeepRoute: false, NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -82,7 +90,7 @@ var testingAccount = &server.Account{ func initRoutesTestData() *RoutesHandler { return &RoutesHandler{ accountManager: &mock_server.MockAccountManager{ - GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) { + GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { if routeID == existingRouteID { return baseExistingRoute, nil } @@ -90,43 +98,48 @@ func initRoutesTestData() *RoutesHandler { route := baseExistingRoute.Copy() route.PeerGroups = []string{existingGroupID} return route, nil + } else if routeID == existingRouteID3 { + route := baseExistingRoute.Copy() + route.Domains = domain.List{existingDomain} + return route, nil } return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) { + CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID { return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0]) } - networkType, p, _ := route.ParseNetwork(network) return &route.Route{ ID: existingRouteID, NetID: netID, Peer: peerID, PeerGroups: peerGroups, - Network: p, + Network: prefix, + Domains: domains, NetworkType: networkType, Description: description, Masquerade: masquerade, Enabled: enabled, Groups: groups, + KeepRoute: keepRoute, }, nil }, - SaveRouteFunc: func(_, _ string, r *route.Route) error { + SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error { if r.Peer == notFoundPeerID { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer) } return nil }, - DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error { + DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error { if routeID != existingRouteID { return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID) } return nil }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingAccount, testingAccount.Users["test_user"], nil }, }, @@ -146,6 +159,9 @@ func TestRoutesHandlers(t *testing.T) { baseExistingRouteWithPeerGroups := baseExistingRoute.Copy() baseExistingRouteWithPeerGroups.PeerGroups = []string{existingGroupID} + baseExistingRouteWithDomains := baseExistingRoute.Copy() + baseExistingRouteWithDomains.Domains = domain.List{existingDomain} + tt := []struct { name string expectedStatus int @@ -161,7 +177,7 @@ func TestRoutesHandlers(t *testing.T) { requestPath: "/api/routes/" + existingRouteID, expectedStatus: http.StatusOK, expectedBody: true, - expectedRoute: toRouteResponse(baseExistingRoute), + expectedRoute: toApiRoute(t, baseExistingRoute), }, { name: "Get Not Existing Route", @@ -175,7 +191,15 @@ func TestRoutesHandlers(t *testing.T) { requestPath: "/api/routes/" + existingRouteID2, expectedStatus: http.StatusOK, expectedBody: true, - expectedRoute: toRouteResponse(baseExistingRouteWithPeerGroups), + expectedRoute: toApiRoute(t, baseExistingRouteWithPeerGroups), + }, + { + name: "Get Existing Route with Domains", + requestType: http.MethodGet, + requestPath: "/api/routes/" + existingRouteID3, + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRoute: toApiRoute(t, baseExistingRouteWithDomains), }, { name: "Delete Existing Route", @@ -191,18 +215,18 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusNotFound, }, { - name: "POST OK", + name: "Network POST OK", requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBuffer( - []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID))), + []byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID))), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: "192.168.0.0/16", + Network: toPtr("192.168.0.0/16"), Peer: &existingPeerID, NetworkType: route.IPv4NetworkString, Masquerade: false, @@ -210,6 +234,28 @@ func TestRoutesHandlers(t *testing.T) { Groups: []string{existingGroupID}, }, }, + { + name: "Domains POST OK", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID))), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRoute: &api.Route{ + Id: existingRouteID, + Description: "Post", + NetworkId: "domainNet", + Network: toPtr("invalid Prefix"), + KeepRoute: true, + Domains: &[]string{existingDomain}, + Peer: &existingPeerID, + NetworkType: route.DomainNetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + }, + }, { name: "POST Non Linux Peer", requestType: http.MethodPost, @@ -242,6 +288,32 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "POST Invalid Domains", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["-example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, + { + name: "POST UnprocessableEntity when both network and domains are provided", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","domains":["example.com"],"network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, + { + name: "POST UnprocessableEntity when no network and domains are provided", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf(`{"Description":"Post","network_id":"awesomeNet","groups":["%s"]}`, existingPeerID))), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, { name: "POST UnprocessableEntity when both peer and peer_groups are provided", requestType: http.MethodPost, @@ -261,7 +333,7 @@ func TestRoutesHandlers(t *testing.T) { expectedBody: false, }, { - name: "PUT OK", + name: "Network PUT OK", requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID)), @@ -271,7 +343,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: "192.168.0.0/16", + Network: toPtr("192.168.0.0/16"), Peer: &existingPeerID, NetworkType: route.IPv4NetworkString, Masquerade: false, @@ -279,6 +351,27 @@ func TestRoutesHandlers(t *testing.T) { Groups: []string{existingGroupID}, }, }, + { + name: "Domains PUT OK", + requestType: http.MethodPut, + requestPath: "/api/routes/" + existingRouteID, + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID)), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRoute: &api.Route{ + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: toPtr("invalid Prefix"), + Domains: &[]string{existingDomain}, + Peer: &existingPeerID, + NetworkType: route.DomainNetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + KeepRoute: true, + }, + }, { name: "PUT OK when peer_groups provided", requestType: http.MethodPut, @@ -290,7 +383,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: "192.168.0.0/16", + Network: toPtr("192.168.0.0/16"), Peer: &emptyString, PeerGroups: &[]string{existingGroupID}, NetworkType: route.IPv4NetworkString, @@ -339,6 +432,33 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "PUT Invalid Domains", + requestType: http.MethodPut, + requestPath: "/api/routes/" + existingRouteID, + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf(`{"Description":"Post","domains":["-example.com"],"network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, + { + name: "PUT UnprocessableEntity when both network and domains are provided", + requestType: http.MethodPut, + requestPath: "/api/routes/" + existingRouteID, + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","domains":["example.com"],"network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, + { + name: "PUT UnprocessableEntity when no network and domains are provided", + requestType: http.MethodPut, + requestPath: "/api/routes/" + existingRouteID, + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf(`{"Description":"Post","network_id":"awesomeNet","peer":"%s","peer_groups":["%s"],"groups":["%s"]}`, existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, { name: "PUT UnprocessableEntity when both peer and peer_groups are provided", requestType: http.MethodPut, @@ -399,3 +519,85 @@ func TestRoutesHandlers(t *testing.T) { }) } } + +func TestValidateDomains(t *testing.T) { + tests := []struct { + name string + domains []string + expected domain.List + wantErr bool + }{ + { + name: "Empty list", + domains: nil, + expected: nil, + wantErr: true, + }, + { + name: "Valid ASCII domain", + domains: []string{"sub.ex-ample.com"}, + expected: domain.List{"sub.ex-ample.com"}, + wantErr: false, + }, + { + name: "Valid Unicode domain", + domains: []string{"münchen.de"}, + expected: domain.List{"xn--mnchen-3ya.de"}, + wantErr: false, + }, + { + name: "Valid Unicode, all labels", + domains: []string{"中国.中国.中国"}, + expected: domain.List{"xn--fiqs8s.xn--fiqs8s.xn--fiqs8s"}, + wantErr: false, + }, + { + name: "With underscores", + domains: []string{"_jabber._tcp.gmail.com"}, + expected: domain.List{"_jabber._tcp.gmail.com"}, + wantErr: false, + }, + { + name: "Invalid domain format", + domains: []string{"-example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid domain format 2", + domains: []string{"example.com-"}, + expected: nil, + wantErr: true, + }, + { + name: "Multiple domains valid and invalid", + domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"}, + expected: domain.List{"google.com"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := validateDomains(tt.domains) + assert.Equal(t, tt.wantErr, err != nil) + assert.Equal(t, got, tt.expected) + }) + } +} + +func toApiRoute(t *testing.T, r *route.Route) *api.Route { + t.Helper() + + apiRoute, err := toRouteResponse(r) + // json flattens pointer to nil slices to null + if apiRoute.Domains != nil && *apiRoute.Domains == nil { + apiRoute.Domains = nil + } + require.NoError(t, err, "Failed to convert route") + return apiRoute +} + +func toPtr[T any](v T) *T { + return &v +} diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 5faedea13..8ee7dfaba 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -1,6 +1,7 @@ package http import ( + "context" "encoding/json" "net/http" "time" @@ -34,9 +35,9 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) // CreateSetupKey is a POST requests that creates a new SetupKey func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -48,13 +49,13 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w) return } if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { - util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w) return } @@ -63,7 +64,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request day := time.Hour * 24 year := day * 365 if expiresIn < day || expiresIn > year { - util.WriteError(status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w) return } @@ -75,54 +76,54 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, req.AutoGroups, req.UsageLimit, user.Id, ephemeral) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - writeSuccess(w, setupKey) + writeSuccess(r.Context(), w, setupKey) } // GetSetupKey is a GET request to get a SetupKey by ID func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) return } - key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID) + key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - writeSuccess(w, key) + writeSuccess(r.Context(), w, key) } // UpdateSetupKey is a PUT request to update server.SetupKey func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) return } @@ -134,12 +135,12 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w) return } if req.AutoGroups == nil { - util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) return } @@ -149,26 +150,26 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request newKey.Name = req.Name newKey.Id = keyID - newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey, user.Id) + newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - writeSuccess(w, newKey) + writeSuccess(r.Context(), w, newKey) } // GetAllSetupKeys is a GET request that returns a list of SetupKey func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id) + setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -177,15 +178,15 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) } - util.WriteJSONObject(w, apiSetupKeys) + util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { +func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) err := json.NewEncoder(w).Encode(toResponseBody(key)) if err != nil { - util.WriteError(err, w) + util.WriteError(ctx, err, w) return } } diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index ebbd5954f..bfa0ec008 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -33,7 +34,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ) *SetupKeysHandler { return &SetupKeysHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return &server.Account{ Id: testAccountID, Domain: "hotmail.com", @@ -49,7 +50,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup }, }, user, nil }, - CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, + CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, ) (*server.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { @@ -59,7 +60,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } return nil, fmt.Errorf("failed creating setup key") }, - GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) { + GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { switch keyID { case defaultKey.Id: return defaultKey, nil @@ -70,14 +71,14 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } }, - SaveSetupKeyFunc: func(accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) { + SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) { if key.Id == updatedSetupKey.Id { return updatedSetupKey, nil } return nil, status.Errorf(status.NotFound, "key %s not found", key.Id) }, - ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) { + ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { return []*server.SetupKey{defaultKey}, nil }, }, diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 531822668..2c2aed842 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) userID := vars["userId"] if len(userID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } existingUser, ok := account.Users[userID] if !ok { - util.WriteError(status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) return } @@ -74,11 +74,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { userRole := server.StrRoleToUserRole(req.Role) if userRole == server.UserRoleUnknown { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w) return } - newUser, err := h.accountManager.SaveUser(account.Id, user.Id, &server.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{ Id: userID, Role: userRole, AutoGroups: req.AutoGroups, @@ -88,10 +88,10 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { }) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } // DeleteUser is a DELETE request to delete a user @@ -102,26 +102,26 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID) + err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). @@ -132,9 +132,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -146,7 +146,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { } if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { - util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) return } @@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{ Email: email, Name: name, Role: req.Role, @@ -169,10 +169,10 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { Issued: server.UserIssuedAPI, }) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } // GetAllUsers returns a list of users of the account this user belongs to. @@ -184,42 +184,42 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id) + data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } serviceUser := r.URL.Query().Get("service_user") users := make([]*api.User, 0) - for _, r := range data { - if r.NonDeletable { + for _, d := range data { + if d.NonDeletable { continue } if serviceUser == "" { - users = append(users, toUserResponse(r, claims.UserId)) + users = append(users, toUserResponse(d, claims.UserId)) continue } includeServiceUser, err := strconv.ParseBool(serviceUser) - log.Debugf("Should include service user: %v", includeServiceUser) + log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser) if err != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w) return } - if includeServiceUser == r.IsServiceUser { - users = append(users, toUserResponse(r, claims.UserId)) + if includeServiceUser == d.IsServiceUser { + users = append(users, toUserResponse(d, claims.UserId)) } } - util.WriteJSONObject(w, users) + util.WriteJSONObject(r.Context(), w, users) } // InviteUser resend invitations to users who haven't activated their accounts, @@ -231,26 +231,26 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID) + err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index 91f19d8d8..a78ac3a4e 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -27,7 +28,7 @@ const ( var usersTestAccount = &server.Account{ Id: existingAccountID, - Domain: domain, + Domain: testDomain, Users: map[string]*server.User{ existingUserID: { Id: existingUserID, @@ -63,10 +64,10 @@ var usersTestAccount = &server.Account{ func initUsersTestData() *UsersHandler { return &UsersHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return usersTestAccount, usersTestAccount.Users[claims.UserId], nil }, - GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { users := make([]*server.UserInfo, 0) for _, v := range usersTestAccount.Users { users = append(users, &server.UserInfo{ @@ -81,13 +82,13 @@ func initUsersTestData() *UsersHandler { } return users, nil }, - CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) { + CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) { if userID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } return key, nil }, - DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error { + DeleteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error { if targetUserID == notFoundUserID { return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID) } @@ -96,7 +97,7 @@ func initUsersTestData() *UsersHandler { } return nil }, - SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) { + SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) { if update.Id == notFoundUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id) } @@ -111,7 +112,7 @@ func initUsersTestData() *UsersHandler { } return info, nil }, - InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error { + InviteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error { if initiatorUserID != existingUserID { return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID) } @@ -127,7 +128,7 @@ func initUsersTestData() *UsersHandler { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ UserId: existingUserID, - Domain: domain, + Domain: testDomain, AccountId: existingAccountID, } }), diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index acaa2838c..603c1c696 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -1,6 +1,7 @@ package util import ( + "context" "encoding/json" "errors" "fmt" @@ -19,12 +20,12 @@ type ErrorResponse struct { } // WriteJSONObject simply writes object to the HTTP response in JSON format -func WriteJSONObject(w http.ResponseWriter, obj interface{}) { +func WriteJSONObject(ctx context.Context, w http.ResponseWriter, obj interface{}) { w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.WriteHeader(http.StatusOK) err := json.NewEncoder(w).Encode(obj) if err != nil { - WriteError(err, w) + WriteError(ctx, err, w) return } } @@ -76,8 +77,8 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) { // WriteError converts an error to an JSON error response. // If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise -func WriteError(err error, w http.ResponseWriter) { - log.Errorf("got a handler error: %s", err.Error()) +func WriteError(ctx context.Context, err error, w http.ResponseWriter) { + log.WithContext(ctx).Errorf("got a handler error: %s", err.Error()) errStatus, ok := status.FromError(err) httpStatus := http.StatusInternalServerError msg := "internal server error" @@ -106,7 +107,7 @@ func WriteError(err error, w http.ResponseWriter) { msg = strings.ToLower(err.Error()) } else { unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error()) - log.Error(unhandledMSG) + log.WithContext(ctx).Error(unhandledMSG) } WriteErrorResponse(msg, httpStatus, w) diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 34a5c0de5..497f1944f 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -183,7 +183,7 @@ func (c *Auth0Credentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token -func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) { +func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) { var res *http.Response reqURL := c.clientConfig.AuthIssuer + "/oauth/token" @@ -200,7 +200,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) { req.Header.Add("content-type", "application/json") - log.Debug("requesting new jwt token for idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for idp manager") res, err = c.httpClient.Do(req) if err != nil { @@ -247,7 +247,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo } // Authenticate retrieves access token to use the Auth0 Management API -func (c *Auth0Credentials) Authenticate() (JWTToken, error) { +func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) { c.mux.Lock() defer c.mux.Unlock() @@ -260,14 +260,14 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) { return c.jwtToken, nil } - res, err := c.requestJWTToken() + res, err := c.requestJWTToken(ctx) if err != nil { return c.jwtToken, err } defer func() { err = res.Body.Close() if err != nil { - log.Errorf("error while closing get jwt token response body: %v", err) + log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err) } }() @@ -301,8 +301,8 @@ func requestByUserIDURL(authIssuer, userID string) string { } // GetAccount returns all the users for a given profile. Calls Auth0 API. -func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -353,7 +353,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { return nil, err } - log.Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch)) + log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch)) err = res.Body.Close() if err != nil { @@ -365,7 +365,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { } if len(batch) == 0 || len(batch) < resultsPerPage { - log.Debugf("finished loading users for accountID %s", accountID) + log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID) return list, nil } } @@ -374,8 +374,8 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { } // GetUserDataByID requests user data from auth0 via ID -func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -414,7 +414,7 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) defer func() { err = res.Body.Close() if err != nil { - log.Errorf("error while closing update user app metadata response body: %v", err) + log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err) } }() @@ -426,9 +426,9 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) } // UpdateUserAppMetadata updates user app metadata based on userId and metadata map -func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { +func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error { - jwtToken, err := am.credentials.Authenticate() + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return err } @@ -449,7 +449,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("content-type", "application/json") - log.Debugf("updating IdP metadata for user %s", userID) + log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID) res, err := am.httpClient.Do(req) if err != nil { @@ -466,7 +466,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta defer func() { err = res.Body.Close() if err != nil { - log.Errorf("error while closing update user app metadata response body: %v", err) + log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err) } }() @@ -530,9 +530,9 @@ func buildUserExportRequest() (string, error) { } func (am *Auth0Manager) createRequest( - method string, endpoint string, body io.Reader, + ctx context.Context, method string, endpoint string, body io.Reader, ) (*http.Request, error) { - jwtToken, err := am.credentials.Authenticate() + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -548,8 +548,8 @@ func (am *Auth0Manager) createRequest( return req, nil } -func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { - req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr)) +func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) { + req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr)) if err != nil { return nil, err } @@ -560,20 +560,20 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (* // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { +func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { payloadString, err := buildUserExportRequest() if err != nil { return nil, err } - exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString) + exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString) if err != nil { return nil, err } jobResp, err := am.httpClient.Do(exportJobReq) if err != nil { - log.Debugf("Couldn't get job response %v", err) + log.WithContext(ctx).Debugf("Couldn't get job response %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -583,7 +583,7 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { defer func() { err = jobResp.Body.Close() if err != nil { - log.Errorf("error while closing update user app metadata response body: %v", err) + log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err) } }() if jobResp.StatusCode != 200 { @@ -597,13 +597,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { body, err := io.ReadAll(jobResp.Body) if err != nil { - log.Debugf("Couldn't read export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err) return nil, err } err = am.helper.Unmarshal(body, &exportJobResp) if err != nil { - log.Debugf("Couldn't unmarshal export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err) return nil, err } @@ -614,16 +614,16 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) } - log.Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) + log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) - done, downloadLink, err := am.checkExportJobStatus(exportJobResp.ID) + done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID) if err != nil { - log.Debugf("Failed at getting status checks from exportJob; %v", err) + log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err) return nil, err } if done { - return am.downloadProfileExport(downloadLink) + return am.downloadProfileExport(ctx, downloadLink) } return nil, fmt.Errorf("failed extracting user profiles from auth0") @@ -632,13 +632,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { // GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list. // This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with // the same email but different connections that are considered as separate accounts (e.g., Google and username/password). -func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email) - body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken) + body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken) if err != nil { return nil, err } @@ -651,7 +651,7 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) { err = am.helper.Unmarshal(body, &userResp) if err != nil { - log.Debugf("Couldn't unmarshal export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err) return nil, err } @@ -659,13 +659,13 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) { } // CreateUser creates a new user in Auth0 Idp and sends an invite -func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { +func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail) if err != nil { return nil, err } - req, err := am.createPostRequest("/api/v2/users", payloadString) + req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString) if err != nil { return nil, err } @@ -676,7 +676,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string resp, err := am.httpClient.Do(req) if err != nil { - log.Debugf("Couldn't get job response %v", err) + log.WithContext(ctx).Debugf("Couldn't get job response %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -686,7 +686,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string defer func() { err = resp.Body.Close() if err != nil { - log.Errorf("error while closing create user response body: %v", err) + log.WithContext(ctx).Errorf("error while closing create user response body: %v", err) } }() if !(resp.StatusCode == 200 || resp.StatusCode == 201) { @@ -700,13 +700,13 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string body, err := io.ReadAll(resp.Body) if err != nil { - log.Debugf("Couldn't read export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err) return nil, err } err = am.helper.Unmarshal(body, &createResp) if err != nil { - log.Debugf("Couldn't unmarshal export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err) return nil, err } @@ -714,14 +714,14 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string return nil, fmt.Errorf("couldn't create user: response %v", resp) } - log.Debugf("created user %s in account %s", createResp.ID, accountID) + log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID) return &createResp, nil } // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (am *Auth0Manager) InviteUserByID(userID string) error { +func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error { userVerificationReq := userVerificationJobRequest{ UserID: userID, } @@ -731,14 +731,14 @@ func (am *Auth0Manager) InviteUserByID(userID string) error { return err } - req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload)) + req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload)) if err != nil { return err } resp, err := am.httpClient.Do(req) if err != nil { - log.Debugf("Couldn't get job response %v", err) + log.WithContext(ctx).Debugf("Couldn't get job response %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -748,7 +748,7 @@ func (am *Auth0Manager) InviteUserByID(userID string) error { defer func() { err = resp.Body.Close() if err != nil { - log.Errorf("error while closing invite user response body: %v", err) + log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err) } }() if !(resp.StatusCode == 200 || resp.StatusCode == 201) { @@ -762,15 +762,15 @@ func (am *Auth0Manager) InviteUserByID(userID string) error { } // DeleteUser from Auth0 -func (am *Auth0Manager) DeleteUser(userID string) error { - req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil) +func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error { + req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil) if err != nil { return err } resp, err := am.httpClient.Do(req) if err != nil { - log.Debugf("execute delete request: %v", err) + log.WithContext(ctx).Debugf("execute delete request: %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -780,7 +780,7 @@ func (am *Auth0Manager) DeleteUser(userID string) error { defer func() { err = resp.Body.Close() if err != nil { - log.Errorf("close delete request body: %v", err) + log.WithContext(ctx).Errorf("close delete request body: %v", err) } }() if resp.StatusCode != 204 { @@ -795,20 +795,20 @@ func (am *Auth0Manager) DeleteUser(userID string) error { // GetAllConnections returns detailed list of all connections filtered by given params. // Note this method is not part of the IDP Manager interface as this is Auth0 specific. -func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, error) { +func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) { var connections []Connection q := make(url.Values) q.Set("strategy", strings.Join(strategy, ",")) - req, err := am.createRequest(http.MethodGet, "/api/v2/connections?"+q.Encode(), nil) + req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil) if err != nil { return connections, err } resp, err := am.httpClient.Do(req) if err != nil { - log.Debugf("execute get connections request: %v", err) + log.WithContext(ctx).Debugf("execute get connections request: %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -818,7 +818,7 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro defer func() { err = resp.Body.Close() if err != nil { - log.Errorf("close get connections request body: %v", err) + log.WithContext(ctx).Errorf("close get connections request body: %v", err) } }() if resp.StatusCode != 200 { @@ -830,13 +830,13 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro body, err := io.ReadAll(resp.Body) if err != nil { - log.Debugf("Couldn't read get connections response; %v", err) + log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err) return connections, err } err = am.helper.Unmarshal(body, &connections) if err != nil { - log.Debugf("Couldn't unmarshal get connection response; %v", err) + log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err) return connections, err } @@ -845,23 +845,23 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro // checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // If the status is "completed", then return the downloadLink -func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) +func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) { + ctx, cancel := context.WithTimeout(ctx, 90*time.Second) defer cancel() retry := time.NewTicker(10 * time.Second) for { select { case <-ctx.Done(): - log.Debugf("Export job status stopped...\n") + log.WithContext(ctx).Debugf("Export job status stopped...\n") return false, "", ctx.Err() case <-retry.C: - jwtToken, err := am.credentials.Authenticate() + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return false, "", err } statusURL := am.authIssuer + "/api/v2/jobs/" + jobID - body, err := doGetReq(am.httpClient, statusURL, jwtToken.AccessToken) + body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken) if err != nil { return false, "", err } @@ -872,7 +872,7 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) return false, "", err } - log.Debugf("current export job status is %v", status.Status) + log.WithContext(ctx).Debugf("current export job status is %v", status.Status) if status.Status != "completed" { continue @@ -884,8 +884,8 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) } // downloadProfileExport downloads user profiles from auth0 batch job -func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*UserData, error) { - body, err := doGetReq(am.httpClient, location, "") +func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) { + body, err := doGetReq(ctx, am.httpClient, location, "") if err != nil { return nil, err } @@ -927,7 +927,7 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us } // Boilerplate implementation for Get Requests. -func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) { +func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err @@ -945,7 +945,7 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) defer func() { err = res.Body.Close() if err != nil { - log.Errorf("error while closing body for url %s: %v", url, err) + log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err) } }() body, err := io.ReadAll(res.Body) diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index febc0ab89..de42ced99 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "encoding/json" "fmt" "io" @@ -60,7 +61,7 @@ type mockAuth0Credentials struct { err error } -func (mc *mockAuth0Credentials) Authenticate() (JWTToken, error) { +func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) { return mc.jwtToken, mc.err } @@ -126,7 +127,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) { helper: testCase.helper, } - res, err := creds.requestJWTToken() + res, err := creds.requestJWTToken(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -295,7 +296,7 @@ func TestAuth0_Authenticate(t *testing.T) { creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -417,7 +418,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { helper: testCase.helper, } - err := manager.UpdateUserAppMetadata("1", testCase.appMetadata) + err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata) testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match") diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index b39f2b5cb..00d30d645 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -116,7 +116,7 @@ func (ac *AuthentikCredentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token. -func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) { +func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) { data := url.Values{} data.Set("client_id", ac.clientConfig.ClientID) data.Set("username", ac.clientConfig.Username) @@ -131,7 +131,7 @@ func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) { } req.Header.Add("content-type", "application/x-www-form-urlencoded") - log.Debug("requesting new jwt token for authentik idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager") resp, err := ac.httpClient.Do(req) if err != nil { @@ -183,7 +183,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) ( } // Authenticate retrieves access token to use the authentik management API. -func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) { +func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) { ac.mux.Lock() defer ac.mux.Unlock() @@ -197,7 +197,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) { return ac.jwtToken, nil } - resp, err := ac.requestJWTToken() + resp, err := ac.requestJWTToken(ctx) if err != nil { return ac.jwtToken, err } @@ -214,13 +214,13 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) { } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // GetUserDataByID requests user data from authentik via ID. -func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { - ctx, err := am.authenticationContext() +func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { + ctx, err := am.authenticationContext(ctx) if err != nil { return nil, err } @@ -254,8 +254,8 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada } // GetAccount returns all the users for a given profile. -func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { - users, err := am.getAllUsers() +func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + users, err := am.getAllUsers(ctx) if err != nil { return nil, err } @@ -274,8 +274,8 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { - users, err := am.getAllUsers() +func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + users, err := am.getAllUsers(ctx) if err != nil { return nil, err } @@ -291,12 +291,12 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { } // getAllUsers returns all users in a Authentik account. -func (am *AuthentikManager) getAllUsers() ([]*UserData, error) { +func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) { users := make([]*UserData, 0) page := int32(1) for { - ctx, err := am.authenticationContext() + ctx, err := am.authenticationContext(ctx) if err != nil { return nil, err } @@ -329,14 +329,14 @@ func (am *AuthentikManager) getAllUsers() ([]*UserData, error) { } // CreateUser creates a new user in authentik Idp and sends an invitation. -func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) { - ctx, err := am.authenticationContext() +func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + ctx, err := am.authenticationContext(ctx) if err != nil { return nil, err } @@ -368,13 +368,13 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) { // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (am *AuthentikManager) InviteUserByID(_ string) error { +func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from Authentik -func (am *AuthentikManager) DeleteUser(userID string) error { - ctx, err := am.authenticationContext() +func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error { + ctx, err := am.authenticationContext(ctx) if err != nil { return err } @@ -404,8 +404,8 @@ func (am *AuthentikManager) DeleteUser(userID string) error { return nil } -func (am *AuthentikManager) authenticationContext() (context.Context, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } diff --git a/management/server/idp/authentik_test.go b/management/server/idp/authentik_test.go index 342e16384..029acdce3 100644 --- a/management/server/idp/authentik_test.go +++ b/management/server/idp/authentik_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "strings" @@ -138,7 +139,7 @@ func TestAuthentikRequestJWTToken(t *testing.T) { helper: testCase.helper, } - resp, err := creds.requestJWTToken() + resp, err := creds.requestJWTToken(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -304,7 +305,7 @@ func TestAuthentikAuthenticate(t *testing.T) { } creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 2f21b3b54..35b86764d 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "net/http" @@ -110,7 +111,7 @@ func (ac *AzureCredentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token. -func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) { +func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) { data := url.Values{} data.Set("client_id", ac.clientConfig.ClientID) data.Set("client_secret", ac.clientConfig.ClientSecret) @@ -132,7 +133,7 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) { } req.Header.Add("content-type", "application/x-www-form-urlencoded") - log.Debug("requesting new jwt token for azure idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager") resp, err := ac.httpClient.Do(req) if err != nil { @@ -184,7 +185,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT } // Authenticate retrieves access token to use the azure Management API. -func (ac *AzureCredentials) Authenticate() (JWTToken, error) { +func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) { ac.mux.Lock() defer ac.mux.Unlock() @@ -198,7 +199,7 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) { return ac.jwtToken, nil } - resp, err := ac.requestJWTToken() + resp, err := ac.requestJWTToken(ctx) if err != nil { return ac.jwtToken, err } @@ -215,16 +216,16 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in azure AD Idp. -func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserDataByID requests user data from keycloak via ID. -func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { +func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { q := url.Values{} q.Add("$select", profileFields) - body, err := am.get("users/"+userID, q) + body, err := am.get(ctx, "users/"+userID, q) if err != nil { return nil, err } @@ -247,11 +248,11 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { +func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { q := url.Values{} q.Add("$select", profileFields) - body, err := am.get("users/"+email, q) + body, err := am.get(ctx, "users/"+email, q) if err != nil { return nil, err } @@ -273,8 +274,8 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { } // GetAccount returns all the users for a given profile. -func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { - users, err := am.getAllUsers() +func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + users, err := am.getAllUsers(ctx) if err != nil { return nil, err } @@ -293,8 +294,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { - users, err := am.getAllUsers() +func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + users, err := am.getAllUsers(ctx) if err != nil { return nil, err } @@ -310,19 +311,19 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { } // UpdateUserAppMetadata updates user app metadata based on userID. -func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (am *AzureManager) InviteUserByID(_ string) error { +func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from Azure. -func (am *AzureManager) DeleteUser(userID string) error { - jwtToken, err := am.credentials.Authenticate() +func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return err } @@ -335,7 +336,7 @@ func (am *AzureManager) DeleteUser(userID string) error { req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("content-type", "application/json") - log.Debugf("delete idp user %s", userID) + log.WithContext(ctx).Debugf("delete idp user %s", userID) resp, err := am.httpClient.Do(req) if err != nil { @@ -358,7 +359,7 @@ func (am *AzureManager) DeleteUser(userID string) error { } // getAllUsers returns all users in an Azure AD account. -func (am *AzureManager) getAllUsers() ([]*UserData, error) { +func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) { users := make([]*UserData, 0) q := url.Values{} @@ -366,7 +367,7 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) { q.Add("$top", "500") for nextLink := "users"; nextLink != ""; { - body, err := am.get(nextLink, q) + body, err := am.get(ctx, nextLink, q) if err != nil { return nil, err } @@ -391,8 +392,8 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) { } // get perform Get requests. -func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } diff --git a/management/server/idp/azure_test.go b/management/server/idp/azure_test.go index b4dc96b23..80e85d2b1 100644 --- a/management/server/idp/azure_test.go +++ b/management/server/idp/azure_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "testing" "time" @@ -101,7 +102,7 @@ func TestAzureAuthenticate(t *testing.T) { } creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 896fb707b..09ea8c430 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -39,12 +39,12 @@ type GoogleWorkspaceCredentials struct { appMetrics telemetry.AppMetrics } -func (gc *GoogleWorkspaceCredentials) Authenticate() (JWTToken, error) { +func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) { return JWTToken{}, nil } // NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager. -func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) { +func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 @@ -66,7 +66,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te } // Create a new Admin SDK Directory service client - adminCredentials, err := getGoogleCredentials(config.ServiceAccountKey) + adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey) if err != nil { return nil, err } @@ -90,12 +90,12 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // GetUserDataByID requests user data from Google Workspace via ID. -func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { +func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { user, err := gm.usersService.Get(userID).Do() if err != nil { return nil, err @@ -112,7 +112,7 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App } // GetAccount returns all the users for a given profile. -func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) { +func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) { users, err := gm.getAllUsers() if err != nil { return nil, err @@ -132,7 +132,7 @@ func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, err // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) { +func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) { users, err := gm.getAllUsers() if err != nil { return nil, err @@ -177,13 +177,13 @@ func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) { } // CreateUser creates a new user in Google Workspace and sends an invitation. -func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) { +func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) { user, err := gm.usersService.Get(email).Do() if err != nil { return nil, err @@ -201,12 +201,12 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error { +func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from GoogleWorkspace. -func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error { +func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error { if err := gm.usersService.Delete(userID).Do(); err != nil { return err } @@ -222,8 +222,8 @@ func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error { // It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. // If that fails, it falls back to using the default Google credentials path. // It returns the retrieved credentials or an error if unsuccessful. -func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) { - log.Debug("retrieving google credentials from the base64 encoded service account key") +func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) { + log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key") decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey) if err != nil { return nil, fmt.Errorf("failed to decode service account key: %w", err) @@ -239,8 +239,8 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) return creds, nil } - log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) - log.Debug("falling back to default google credentials location") + log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) + log.WithContext(ctx).Debug("falling back to default google credentials location") creds, err = google.FindDefaultCredentials( context.Background(), diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 7adb76f40..419220942 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "net/http" "strings" @@ -16,14 +17,14 @@ const ( // Manager idp manager interface type Manager interface { - UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error - GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) - GetAccount(accountId string) ([]*UserData, error) - GetAllAccounts() (map[string][]*UserData, error) - CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) - GetUserByEmail(email string) ([]*UserData, error) - InviteUserByID(userID string) error - DeleteUser(userID string) error + UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error + GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) + GetAccount(ctx context.Context, accountId string) ([]*UserData, error) + GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) + CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) + GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) + InviteUserByID(ctx context.Context, userID string) error + DeleteUser(ctx context.Context, userID string) error } // ClientConfig defines common client configuration for all IdP manager @@ -51,7 +52,7 @@ type Config struct { // ManagerCredentials interface that authenticates using the credential of each type of idp type ManagerCredentials interface { - Authenticate() (JWTToken, error) + Authenticate(ctx context.Context) (JWTToken, error) } // ManagerHTTPClient http client interface for API calls @@ -91,7 +92,7 @@ type JWTToken struct { } // NewManager returns a new idp manager based on the configuration that it receives -func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) { +func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) { if config.ClientConfig != nil { config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/") } @@ -175,7 +176,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"], CustomerID: config.ExtraConfig["CustomerId"], } - return NewGoogleWorkspaceManager(googleClientConfig, appMetrics) + return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics) case "jumpcloud": jumpcloudConfig := JumpCloudClientConfig{ APIToken: config.ExtraConfig["ApiToken"], diff --git a/management/server/idp/jumpcloud.go b/management/server/idp/jumpcloud.go index 0115b4049..6345e424a 100644 --- a/management/server/idp/jumpcloud.go +++ b/management/server/idp/jumpcloud.go @@ -74,7 +74,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM } // Authenticate retrieves access token to use the JumpCloud user API. -func (jc *JumpCloudCredentials) Authenticate() (JWTToken, error) { +func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) { return JWTToken{}, nil } @@ -85,12 +85,12 @@ func (jm *JumpCloudManager) authenticationContext() context.Context { } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (jm *JumpCloudManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // GetUserDataByID requests user data from JumpCloud via ID. -func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { +func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { authCtx := jm.authenticationContext() user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil) if err != nil { @@ -116,7 +116,7 @@ func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetada } // GetAccount returns all the users for a given profile. -func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) { +func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) { authCtx := jm.authenticationContext() userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) if err != nil { @@ -148,7 +148,7 @@ func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) { +func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) { authCtx := jm.authenticationContext() userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) if err != nil { @@ -177,13 +177,13 @@ func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) { } // CreateUser creates a new user in JumpCloud Idp and sends an invitation. -func (jm *JumpCloudManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) { +func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) { searchFilter := map[string]interface{}{ "searchFilter": map[string]interface{}{ "filter": []string{email}, @@ -219,12 +219,12 @@ func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) { // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (jm *JumpCloudManager) InviteUserByID(_ string) error { +func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from jumpCloud directory -func (jm *JumpCloudManager) DeleteUser(userID string) error { +func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error { authCtx := jm.authenticationContext() _, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil) if err != nil { diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index 3a6f80d03..07d84058c 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "net/http" @@ -109,7 +110,7 @@ func (kc *KeycloakCredentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token. -func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) { +func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) { data := url.Values{} data.Set("client_id", kc.clientConfig.ClientID) data.Set("client_secret", kc.clientConfig.ClientSecret) @@ -122,7 +123,7 @@ func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) { } req.Header.Add("content-type", "application/x-www-form-urlencoded") - log.Debug("requesting new jwt token for keycloak idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager") resp, err := kc.httpClient.Do(req) if err != nil { @@ -174,7 +175,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J } // Authenticate retrieves access token to use the keycloak Management API. -func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { +func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) { kc.mux.Lock() defer kc.mux.Unlock() @@ -188,7 +189,7 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { return kc.jwtToken, nil } - resp, err := kc.requestJWTToken() + resp, err := kc.requestJWTToken(ctx) if err != nil { return kc.jwtToken, err } @@ -205,18 +206,18 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in keycloak Idp and sends an invite. -func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) { +func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { q := url.Values{} q.Add("email", email) q.Add("exact", "true") - body, err := km.get("users", q) + body, err := km.get(ctx, "users", q) if err != nil { return nil, err } @@ -240,8 +241,8 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) { } // GetUserDataByID requests user data from keycloak via ID. -func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) { - body, err := km.get("users/"+userID, nil) +func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) { + body, err := km.get(ctx, "users/"+userID, nil) if err != nil { return nil, err } @@ -260,8 +261,8 @@ func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserD } // GetAccount returns all the users for a given account profile. -func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { - profiles, err := km.fetchAllUserProfiles() +func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + profiles, err := km.fetchAllUserProfiles(ctx) if err != nil { return nil, err } @@ -283,8 +284,8 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { - profiles, err := km.fetchAllUserProfiles() +func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + profiles, err := km.fetchAllUserProfiles(ctx) if err != nil { return nil, err } @@ -303,19 +304,19 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (km *KeycloakManager) InviteUserByID(_ string) error { +func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from Keycloak by user ID. -func (km *KeycloakManager) DeleteUser(userID string) error { - jwtToken, err := km.credentials.Authenticate() +func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error { + jwtToken, err := km.credentials.Authenticate(ctx) if err != nil { return err } @@ -353,8 +354,8 @@ func (km *KeycloakManager) DeleteUser(userID string) error { return nil } -func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) { - totalUsers, err := km.totalUsersCount() +func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) { + totalUsers, err := km.totalUsersCount(ctx) if err != nil { return nil, err } @@ -362,7 +363,7 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) { q := url.Values{} q.Add("max", fmt.Sprint(*totalUsers)) - body, err := km.get("users", q) + body, err := km.get(ctx, "users", q) if err != nil { return nil, err } @@ -377,8 +378,8 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) { } // get perform Get requests. -func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) { - jwtToken, err := km.credentials.Authenticate() +func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) { + jwtToken, err := km.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -414,8 +415,8 @@ func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) { // totalUsersCount returns the total count of all user created. // Used when fetching all registered accounts with pagination. -func (km *KeycloakManager) totalUsersCount() (*int, error) { - body, err := km.get("users/count", nil) +func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) { + body, err := km.get(ctx, "users/count", nil) if err != nil { return nil, err } diff --git a/management/server/idp/keycloak_test.go b/management/server/idp/keycloak_test.go index 9b6c1d3c6..0daca0671 100644 --- a/management/server/idp/keycloak_test.go +++ b/management/server/idp/keycloak_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "strings" @@ -128,7 +129,7 @@ func TestKeycloakRequestJWTToken(t *testing.T) { helper: testCase.helper, } - resp, err := creds.requestJWTToken() + resp, err := creds.requestJWTToken(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -294,7 +295,7 @@ func TestKeycloakAuthenticate(t *testing.T) { } creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") diff --git a/management/server/idp/mock.go b/management/server/idp/mock.go index 7605466e7..a07e375bf 100644 --- a/management/server/idp/mock.go +++ b/management/server/idp/mock.go @@ -1,77 +1,79 @@ package idp +import "context" + // MockIDP is a mock implementation of the IDP interface type MockIDP struct { - UpdateUserAppMetadataFunc func(userId string, appMetadata AppMetadata) error - GetUserDataByIDFunc func(userId string, appMetadata AppMetadata) (*UserData, error) - GetAccountFunc func(accountId string) ([]*UserData, error) - GetAllAccountsFunc func() (map[string][]*UserData, error) - CreateUserFunc func(email, name, accountID, invitedByEmail string) (*UserData, error) - GetUserByEmailFunc func(email string) ([]*UserData, error) - InviteUserByIDFunc func(userID string) error - DeleteUserFunc func(userID string) error + UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error + GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) + GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error) + GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error) + CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) + GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error) + InviteUserByIDFunc func(ctx context.Context, userID string) error + DeleteUserFunc func(ctx context.Context, userID string) error } // UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method -func (m *MockIDP) UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error { +func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error { if m.UpdateUserAppMetadataFunc != nil { - return m.UpdateUserAppMetadataFunc(userId, appMetadata) + return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata) } return nil } // GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method -func (m *MockIDP) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) { +func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) { if m.GetUserDataByIDFunc != nil { - return m.GetUserDataByIDFunc(userId, appMetadata) + return m.GetUserDataByIDFunc(ctx, userId, appMetadata) } return nil, nil } // GetAccount is a mock implementation of the IDP interface GetAccount method -func (m *MockIDP) GetAccount(accountId string) ([]*UserData, error) { +func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) { if m.GetAccountFunc != nil { - return m.GetAccountFunc(accountId) + return m.GetAccountFunc(ctx, accountId) } return nil, nil } // GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method -func (m *MockIDP) GetAllAccounts() (map[string][]*UserData, error) { +func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { if m.GetAllAccountsFunc != nil { - return m.GetAllAccountsFunc() + return m.GetAllAccountsFunc(ctx) } return nil, nil } // CreateUser is a mock implementation of the IDP interface CreateUser method -func (m *MockIDP) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { +func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { if m.CreateUserFunc != nil { - return m.CreateUserFunc(email, name, accountID, invitedByEmail) + return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail) } return nil, nil } // GetUserByEmail is a mock implementation of the IDP interface GetUserByEmail method -func (m *MockIDP) GetUserByEmail(email string) ([]*UserData, error) { +func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { if m.GetUserByEmailFunc != nil { - return m.GetUserByEmailFunc(email) + return m.GetUserByEmailFunc(ctx, email) } return nil, nil } // InviteUserByID is a mock implementation of the IDP interface InviteUserByID method -func (m *MockIDP) InviteUserByID(userID string) error { +func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error { if m.InviteUserByIDFunc != nil { - return m.InviteUserByIDFunc(userID) + return m.InviteUserByIDFunc(ctx, userID) } return nil } // DeleteUser is a mock implementation of the IDP interface DeleteUser method -func (m *MockIDP) DeleteUser(userID string) error { +func (m *MockIDP) DeleteUser(ctx context.Context, userID string) error { if m.DeleteUserFunc != nil { - return m.DeleteUserFunc(userID) + return m.DeleteUserFunc(ctx, userID) } return nil } diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index c8d33a207..b9cd006be 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -94,17 +94,17 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (* } // Authenticate retrieves access token to use the okta user API. -func (oc *OktaCredentials) Authenticate() (JWTToken, error) { +func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) { return JWTToken{}, nil } // CreateUser creates a new user in okta Idp and sends an invitation. -func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserDataByID requests user data from keycloak via ID. -func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { +func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { user, resp, err := om.client.User.GetUser(context.Background(), userID) if err != nil { return nil, err @@ -132,7 +132,7 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) ( // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) { +func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) { user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email)) if err != nil { return nil, err @@ -160,7 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) { } // GetAccount returns all the users for a given profile. -func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { +func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) { users, err := om.getAllUsers() if err != nil { return nil, err @@ -180,7 +180,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { +func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) { users, err := om.getAllUsers() if err != nil { return nil, err @@ -242,18 +242,18 @@ func (om *OktaManager) getAllUsers() ([]*UserData, error) { } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { +func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (om *OktaManager) InviteUserByID(_ string) error { +func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from Okta -func (om *OktaManager) DeleteUser(userID string) error { +func (om *OktaManager) DeleteUser(_ context.Context, userID string) error { resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil) if err != nil { return err diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 9021d6752..729b49733 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "net/http" @@ -149,7 +150,7 @@ func (zc *ZitadelCredentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token. -func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) { +func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) { data := url.Values{} data.Set("client_id", zc.clientConfig.ClientID) data.Set("client_secret", zc.clientConfig.ClientSecret) @@ -163,7 +164,7 @@ func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) { } req.Header.Add("content-type", "application/x-www-form-urlencoded") - log.Debug("requesting new jwt token for zitadel idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for zitadel idp manager") resp, err := zc.httpClient.Do(req) if err != nil { @@ -215,7 +216,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW } // Authenticate retrieves access token to use the Zitadel Management API. -func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { +func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error) { zc.mux.Lock() defer zc.mux.Unlock() @@ -229,7 +230,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { return zc.jwtToken, nil } - resp, err := zc.requestJWTToken() + resp, err := zc.requestJWTToken(ctx) if err != nil { return zc.jwtToken, err } @@ -246,7 +247,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel. -func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { +func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { firstLast := strings.SplitN(name, " ", 2) var addUser = map[string]any{ @@ -269,7 +270,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri return nil, err } - body, err := zm.post("users/human/_import", string(payload)) + body, err := zm.post(ctx, "users/human/_import", string(payload)) if err != nil { return nil, err } @@ -300,7 +301,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) { +func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { searchByEmail := zitadelAttributes{ "queries": { { @@ -316,7 +317,7 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) { return nil, err } - body, err := zm.post("users/_search", string(payload)) + body, err := zm.post(ctx, "users/_search", string(payload)) if err != nil { return nil, err } @@ -340,8 +341,8 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) { } // GetUserDataByID requests user data from zitadel via ID. -func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { - body, err := zm.get("users/"+userID, nil) +func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { + body, err := zm.get(ctx, "users/"+userID, nil) if err != nil { return nil, err } @@ -363,8 +364,8 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata } // GetAccount returns all the users for a given profile. -func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) { - body, err := zm.post("users/_search", "") +func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + body, err := zm.post(ctx, "users/_search", "") if err != nil { return nil, err } @@ -392,8 +393,8 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) { - body, err := zm.post("users/_search", "") +func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + body, err := zm.post(ctx, "users/_search", "") if err != nil { return nil, err } @@ -419,7 +420,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) { // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // Metadata values are base64 encoded. -func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } @@ -429,7 +430,7 @@ type inviteUserRequest struct { // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (zm *ZitadelManager) InviteUserByID(userID string) error { +func (zm *ZitadelManager) InviteUserByID(ctx context.Context, userID string) error { inviteUser := inviteUserRequest{ Email: userID, } @@ -440,14 +441,14 @@ func (zm *ZitadelManager) InviteUserByID(userID string) error { } // don't care about the body in the response - _, err = zm.post(fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload)) + _, err = zm.post(ctx, fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload)) return err } // DeleteUser from Zitadel -func (zm *ZitadelManager) DeleteUser(userID string) error { +func (zm *ZitadelManager) DeleteUser(ctx context.Context, userID string) error { resource := fmt.Sprintf("users/%s", userID) - if err := zm.delete(resource); err != nil { + if err := zm.delete(ctx, resource); err != nil { return err } @@ -459,8 +460,8 @@ func (zm *ZitadelManager) DeleteUser(userID string) error { } // post perform Post requests. -func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) { - jwtToken, err := zm.credentials.Authenticate() +func (zm *ZitadelManager) post(ctx context.Context, resource string, body string) ([]byte, error) { + jwtToken, err := zm.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -495,8 +496,8 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) { } // delete perform Delete requests. -func (zm *ZitadelManager) delete(resource string) error { - jwtToken, err := zm.credentials.Authenticate() +func (zm *ZitadelManager) delete(ctx context.Context, resource string) error { + jwtToken, err := zm.credentials.Authenticate(ctx) if err != nil { return err } @@ -531,8 +532,8 @@ func (zm *ZitadelManager) delete(resource string) error { } // get perform Get requests. -func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) { - jwtToken, err := zm.credentials.Authenticate() +func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) { + jwtToken, err := zm.credentials.Authenticate(ctx) if err != nil { return nil, err } diff --git a/management/server/idp/zitadel_test.go b/management/server/idp/zitadel_test.go index 9a771b36a..6bc612e78 100644 --- a/management/server/idp/zitadel_test.go +++ b/management/server/idp/zitadel_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "strings" @@ -108,7 +109,7 @@ func TestZitadelRequestJWTToken(t *testing.T) { helper: testCase.helper, } - resp, err := creds.requestJWTToken() + resp, err := creds.requestJWTToken(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -274,7 +275,7 @@ func TestZitadelAuthenticate(t *testing.T) { } creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 198f8d527..05537ada4 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -1,9 +1,10 @@ package server import ( + "context" "errors" - "github.com/google/martian/v3/log" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" ) @@ -19,22 +20,22 @@ import ( // // Returns: // - error: An error if any occurred during the process, otherwise returns nil -func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { - ok, err := am.GroupValidation(accountID, groups) +func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error { + ok, err := am.GroupValidation(ctx, accountID, groups) if err != nil { - log.Debugf("error validating groups: %s", err.Error()) + log.WithContext(ctx).Debugf("error validating groups: %s", err.Error()) return err } if !ok { - log.Debugf("invalid groups") + log.WithContext(ctx).Debugf("invalid groups") return errors.New("invalid groups") } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - a, err := am.Store.GetAccountByUser(userID) + a, err := am.Store.GetAccountByUser(ctx, userID) if err != nil { return err } @@ -48,14 +49,14 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin a.Settings.Extra = extra } extra.IntegratedValidatorGroups = groups - return am.Store.SaveAccount(a) + return am.Store.SaveAccount(ctx, a) } -func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { +func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { if len(groups) == 0 { return true, nil } - accountsGroups, err := am.ListGroups(accountId) + accountsGroups, err := am.ListGroups(ctx, accountId) if err != nil { return false, err } diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index ae9698f79..6c9a3e44e 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -1,6 +1,8 @@ package integrated_validator import ( + "context" + "github.com/netbirdio/netbird/management/server/account" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -8,12 +10,12 @@ import ( // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { - ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error - ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) - PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer - IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) + ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) + PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer + IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) - PeerDeleted(accountID, peerID string) error + PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) - Stop() + Stop(ctx context.Context) } diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index f218c1aa9..c3417a769 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -2,6 +2,7 @@ package jwtclaims import ( "bytes" + "context" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -69,8 +70,8 @@ type JWTValidator struct { } // NewJWTValidator constructor -func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { - keys, err := getPemKeys(keysLocation) +func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { + keys, err := getPemKeys(ctx, keysLocation) if err != nil { return nil, err } @@ -102,19 +103,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string, lock.Lock() defer lock.Unlock() - refreshedKeys, err := getPemKeys(keysLocation) + refreshedKeys, err := getPemKeys(ctx, keysLocation) if err != nil { - log.Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) + log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) refreshedKeys = keys } - log.Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) keys = refreshedKeys } } - cert, err := getPemCert(token, keys) + cert, err := getPemCert(ctx, token, keys) if err != nil { return nil, err } @@ -136,19 +137,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string, } // ValidateAndParse validates the token and returns the parsed token -func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { +func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { // If the token is empty... if token == "" { // Check if it was required if m.options.CredentialsOptional { - log.Debugf("no credentials found (CredentialsOptional=true)") + log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)") // No error, just no token (and that is ok given that CredentialsOptional is true) return nil, nil //nolint:nilnil } // If we get here, the required token is missing errorMsg := "required authorization token not found" - log.Debugf(" Error: No credentials found (CredentialsOptional=false)") + log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)") return nil, fmt.Errorf(errorMsg) } @@ -157,7 +158,7 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { // Check if there was an error in parsing... if err != nil { - log.Errorf("error parsing token: %v", err) + log.WithContext(ctx).Errorf("error parsing token: %v", err) return nil, fmt.Errorf("Error parsing token: %w", err) } @@ -165,14 +166,14 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", m.options.SigningMethod.Alg(), parsedToken.Header["alg"]) - log.Debugf("error validating token algorithm: %s", errorMsg) + log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg) return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) } // Check if the parsed token is valid... if !parsedToken.Valid { errorMsg := "token is invalid" - log.Debugf(errorMsg) + log.WithContext(ctx).Debugf(errorMsg) return nil, errors.New(errorMsg) } @@ -184,7 +185,7 @@ func (jwks *Jwks) stillValid() bool { return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) } -func getPemKeys(keysLocation string) (*Jwks, error) { +func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) { resp, err := http.Get(keysLocation) if err != nil { return nil, err @@ -198,13 +199,13 @@ func getPemKeys(keysLocation string) (*Jwks, error) { } cacheControlHeader := resp.Header.Get("Cache-Control") - expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader) + expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader) jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) return jwks, err } -func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) { +func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) { // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time cert := "" @@ -217,7 +218,7 @@ func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) { cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" return cert, nil } - log.Debugf("generating validation pem from JWK") + log.WithContext(ctx).Debugf("generating validation pem from JWK") return generatePemFromJWK(jwks.Keys[k]) } @@ -284,7 +285,7 @@ func convertExponentStringToInt(stringExponent string) (int, error) { } // getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header -func getMaxAgeFromCacheHeader(cacheControl string) int { +func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { // Split into individual directives directives := strings.Split(cacheControl, ",") @@ -295,7 +296,7 @@ func getMaxAgeFromCacheHeader(cacheControl string) int { maxAgeStr := strings.TrimPrefix(directive, "max-age=") maxAge, err := strconv.Atoi(maxAgeStr) if err != nil { - log.Debugf("error parsing max-age: %v", err) + log.WithContext(ctx).Debugf("error parsing max-age: %v", err) return 0 } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 1b7bced3c..7976e76e4 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -134,7 +134,8 @@ func Test_SyncProtocol(t *testing.T) { // take the first registered peer as a base for the test. Total four. key := *peers[0] - message, err := encryption.EncryptMessage(*serverKey, key, &mgmtProto.SyncRequest{}) + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + message, err := encryption.EncryptMessage(*serverKey, key, syncReq) if err != nil { t.Fatal(err) return @@ -169,7 +170,7 @@ func Test_SyncProtocol(t *testing.T) { } if wiretrusteeConfig.GetSignal() == nil { - t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal turnCfg") + t.Fatal("expecting SyncResponse to have WiretrusteeConfig with non-nil Signal config") } expectedSignalConfig := &mgmtProto.HostConfig{ @@ -405,7 +406,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := NewTestStoreFromJson(config.Datadir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { return nil, "", err } @@ -413,7 +414,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", + accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) if err != nil { return nil, "", err @@ -421,7 +422,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "") ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) + mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 0ad8426cf..1ef97f73b 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -93,7 +93,8 @@ var _ = Describe("Management service", func() { key, _ := wgtypes.GenerateKey() loginPeerWithValidSetupKey(serverPubKey, key, client) - encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{}) + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq) Expect(err).NotTo(HaveOccurred()) sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ @@ -143,7 +144,7 @@ var _ = Describe("Management service", func() { loginPeerWithValidSetupKey(serverPubKey, key1, client) loginPeerWithValidSetupKey(serverPubKey, key2, client) - messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) + messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}) Expect(err).NotTo(HaveOccurred()) encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key) Expect(err).NotTo(HaveOccurred()) @@ -176,7 +177,7 @@ var _ = Describe("Management service", func() { key, _ := wgtypes.GenerateKey() loginPeerWithValidSetupKey(serverPubKey, key, client) - messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) + messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}) Expect(err).NotTo(HaveOccurred()) encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key) Expect(err).NotTo(HaveOccurred()) @@ -329,7 +330,7 @@ var _ = Describe("Management service", func() { var clients []mgmtProto.ManagementService_SyncClient for _, peer := range peers { - messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) + messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}) Expect(err).NotTo(HaveOccurred()) encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer) Expect(err).NotTo(HaveOccurred()) @@ -394,7 +395,8 @@ var _ = Describe("Management service", func() { defer GinkgoRecover() key, _ := wgtypes.GenerateKey() loginPeerWithValidSetupKey(serverPubKey, key, client) - encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{}) + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq) Expect(err).NotTo(HaveOccurred()) // open stream @@ -449,11 +451,11 @@ var _ = Describe("Management service", func() { type MocIntegratedValidator struct { } -func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { +func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { return nil } -func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { return update, nil } @@ -465,15 +467,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s return validatedPeers, nil } -func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { +func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { return peer } -func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { +func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { return false, false, nil } -func (MocIntegratedValidator) PeerDeleted(_, _ string) error { +func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { return nil } @@ -481,7 +483,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string) } -func (MocIntegratedValidator) Stop() {} +func (MocIntegratedValidator) Stop(_ context.Context) {} func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -532,20 +534,20 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromJson(config.Datadir) + store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) if err != nil { log.Fatalf("failed creating a manager: %v", err) } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "") - mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 9da1e577e..bdf744d21 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "os" "sort" "strings" "time" @@ -24,7 +25,7 @@ const ( // payloadEndpoint metrics defaultEndpoint to send anonymous data payloadEndpoint = "https://metrics.netbird.io" // defaultPushInterval default interval to push metrics - defaultPushInterval = 24 * time.Hour + defaultPushInterval = 12 * time.Hour // requestTimeout http request timeout requestTimeout = 45 * time.Second ) @@ -46,7 +47,7 @@ type properties map[string]interface{} // DataSource metric data source type DataSource interface { - GetAllAccounts() []*server.Account + GetAllAccounts(ctx context.Context) []*server.Account GetStoreEngine() server.StoreEngine } @@ -81,29 +82,45 @@ func NewWorker(ctx context.Context, id string, dataSource DataSource, connManage } // Run runs the metrics worker -func (w *Worker) Run() { - pushTicker := time.NewTicker(defaultPushInterval) +func (w *Worker) Run(ctx context.Context) { + interval := getMetricsInterval(ctx) + + pushTicker := time.NewTicker(interval) for { select { case <-w.ctx.Done(): return case <-pushTicker.C: - err := w.sendMetrics() + err := w.sendMetrics(ctx) if err != nil { - log.Error(err) + log.WithContext(ctx).Error(err) } w.lastRun = time.Now() } } } -func (w *Worker) sendMetrics() error { +func getMetricsInterval(ctx context.Context) time.Duration { + interval := defaultPushInterval + if os.Getenv("NETBIRD_METRICS_INTERVAL_IN_SECONDS") != "" { + newInterval, err := time.ParseDuration(os.Getenv("NETBIRD_METRICS_INTERVAL_IN_SECONDS") + "s") + if err != nil { + log.WithContext(ctx).Errorf("unable to parse NETBIRD_METRICS_INTERVAL_IN_SECONDS, using default interval %v. Error: %v", defaultPushInterval, err) + } else { + log.WithContext(ctx).Infof("using NETBIRD_METRICS_INTERVAL_IN_SECONDS %s", newInterval) + interval = newInterval + } + } + return interval +} + +func (w *Worker) sendMetrics(ctx context.Context) error { apiKey, err := getAPIKey(w.ctx) if err != nil { return err } - payload := w.generatePayload(apiKey) + payload := w.generatePayload(ctx, apiKey) payloadString, err := buildMetricsPayload(payload) if err != nil { @@ -112,10 +129,11 @@ func (w *Worker) sendMetrics() error { httpClient := http.Client{} - exportJobReq, err := createPostRequest(w.ctx, payloadEndpoint+"/capture/", payloadString) + exportJobReq, cancelCTX, err := createPostRequest(w.ctx, payloadEndpoint+"/capture/", payloadString) if err != nil { return fmt.Errorf("unable to create metrics post request %v", err) } + defer cancelCTX() jobResp, err := httpClient.Do(exportJobReq) if err != nil { @@ -125,7 +143,7 @@ func (w *Worker) sendMetrics() error { defer func() { err = jobResp.Body.Close() if err != nil { - log.Errorf("error while closing update metrics response body: %v", err) + log.WithContext(ctx).Errorf("error while closing update metrics response body: %v", err) } }() @@ -133,15 +151,15 @@ func (w *Worker) sendMetrics() error { return fmt.Errorf("unable to push anonymous metrics, got statusCode %d", jobResp.StatusCode) } - log.Infof("sent anonymous metrics, next push will happen in %s. "+ + log.WithContext(ctx).Infof("sent anonymous metrics, next push will happen in %s. "+ "You can disable these metrics by running with flag --disable-anonymous-metrics,"+ - " see more information at https://netbird.io/docs/FAQ/metrics-collection", defaultPushInterval) + " see more information at https://docs.netbird.io/about-netbird/faq#why-and-what-are-the-anonymous-usage-metrics", getMetricsInterval(ctx)) return nil } -func (w *Worker) generatePayload(apiKey string) pushPayload { - properties := w.generateProperties() +func (w *Worker) generatePayload(ctx context.Context, apiKey string) pushPayload { + properties := w.generateProperties(ctx) return pushPayload{ APIKey: apiKey, @@ -152,7 +170,7 @@ func (w *Worker) generatePayload(apiKey string) pushPayload { } } -func (w *Worker) generateProperties() properties { +func (w *Worker) generateProperties(ctx context.Context) properties { var ( uptime float64 accounts int @@ -192,7 +210,7 @@ func (w *Worker) generateProperties() properties { connections := w.connManager.GetAllConnectedPeers() version = nbversion.NetbirdVersion() - for _, account := range w.dataSource.GetAllAccounts() { + for _, account := range w.dataSource.GetAllAccounts(ctx) { accounts++ if account.Settings.PeerLoginExpirationEnabled { @@ -342,7 +360,7 @@ func getAPIKey(ctx context.Context) (string, error) { defer func() { err = response.Body.Close() if err != nil { - log.Errorf("error while closing metrics token response body: %v", err) + log.WithContext(ctx).Errorf("error while closing metrics token response body: %v", err) } }() @@ -373,20 +391,20 @@ func buildMetricsPayload(payload pushPayload) (string, error) { return string(str), nil } -func createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) { +func createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, context.CancelFunc, error) { ctx, cancel := context.WithTimeout(ctx, requestTimeout) - defer cancel() reqURL := endpoint payload := strings.NewReader(payloadStr) req, err := http.NewRequestWithContext(ctx, "POST", reqURL, payload) if err != nil { - return nil, err + cancel() + return nil, nil, err } req.Header.Add("content-type", "application/json") - return req, nil + return req, cancel, nil } func getMinMaxVersion(inputList []string) (string, string) { diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index c5b18607a..2ac2d68a0 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -1,6 +1,7 @@ package metrics import ( + "context" "testing" nbdns "github.com/netbirdio/netbird/dns" @@ -21,7 +22,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { } // GetAllAccounts returns a list of *server.Account for use in tests with predefined information -func (mockDatasource) GetAllAccounts() []*server.Account { +func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { return []*server.Account{ { Id: "1", @@ -188,7 +189,7 @@ func TestGenerateProperties(t *testing.T) { connManager: ds, } - properties := worker.generateProperties() + properties := worker.generateProperties(context.Background()) if properties["accounts"] != 2 { t.Errorf("expected 2 accounts, got %d", properties["accounts"]) diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 9776418ad..4c8baea5e 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -1,6 +1,7 @@ package migration import ( + "context" "database/sql" "encoding/gob" "encoding/json" @@ -16,7 +17,7 @@ import ( // MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding. // T is the type of the model that contains the field to be migrated. // S is the type of the field to be migrated. -func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) error { +func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, fieldName string) error { oldColumnName := fieldName newColumnName := fieldName + "_tmp" @@ -24,7 +25,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro var model T if !db.Migrator().HasTable(&model) { - log.Debugf("Table for %T does not exist, no migration needed", model) + log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model) return nil } @@ -35,20 +36,23 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro } tableName := stmt.Schema.Table - var item string - if err := db.Model(model).Select(oldColumnName).First(&item).Error; err != nil { + var sqliteItem sql.NullString + if err := db.Model(model).Select(oldColumnName).First(&sqliteItem).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - log.Debugf("No records in table %s, no migration needed", tableName) + log.WithContext(ctx).Debugf("No records in table %s, no migration needed", tableName) return nil } return fmt.Errorf("fetch first record: %w", err) } + item := sqliteItem.String + var js json.RawMessage var syntaxError *json.SyntaxError err = json.Unmarshal([]byte(item), &js) - if err == nil || !errors.As(err, &syntaxError) { - log.Debugf("No migration needed for %s, %s", tableName, fieldName) + // if the item is JSON parsable or an empty string it can not be gob encoded + if err == nil || !errors.As(err, &syntaxError) || item == "" { + log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName) return nil } @@ -74,7 +78,6 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro if err := gob.NewDecoder(reader).Decode(&field); err != nil { return fmt.Errorf("gob decode error: %w", err) } - jsonValue, err := json.Marshal(field) if err != nil { return fmt.Errorf("re-encode to JSON: %w", err) @@ -97,14 +100,14 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro return err } - log.Infof("Migration of %s.%s from gob to json completed", tableName, fieldName) + log.WithContext(ctx).Infof("Migration of %s.%s from gob to json completed", tableName, fieldName) return nil } // MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding. // T is the type of the model that contains the field to be migrated. -func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error { +func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fieldName string, indexName string) error { oldColumnName := fieldName newColumnName := fieldName + "_tmp" @@ -136,7 +139,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, index var syntaxError *json.SyntaxError err = json.Unmarshal([]byte(item.String), &js) if err == nil || !errors.As(err, &syntaxError) { - log.Debugf("No migration needed for %s, %s", tableName, fieldName) + log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName) return nil } } @@ -167,7 +170,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, index columnIpValue := net.IP(blobValue) if net.ParseIP(columnIpValue.String()) == nil { - log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName) + log.WithContext(ctx).Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName) columnIpValue = net.IPv6loopback } diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 45757e9d6..5a1926641 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -1,6 +1,7 @@ package migration_test import ( + "context" "encoding/gob" "net" "strings" @@ -30,7 +31,7 @@ func setupDatabase(t *testing.T) *gorm.DB { func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) { db := setupDatabase(t) - err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net") + err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail for an empty database") } @@ -63,7 +64,7 @@ func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet) require.NoError(t, err, "Failed to decode Gob data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net") + err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with Gob data") var jsonStr string @@ -83,7 +84,7 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) { err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert JSON data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net") + err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with JSON data") var jsonStr string @@ -93,7 +94,7 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) { db := setupDatabase(t) - err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip") + err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "ip", "idx_peers_account_id_ip") require.NoError(t, err, "Migration should not fail for an empty database") } @@ -130,7 +131,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error assert.NoError(t, err, "Failed to fetch blob data") - err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "") + err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "") require.NoError(t, err, "Migration should not fail with net.IP blob data") var jsonStr string @@ -152,7 +153,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { ).Error require.NoError(t, err, "Failed to insert JSON data") - err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "") + err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "") require.NoError(t, err, "Migration should not fail with net.IP JSON data") var jsonStr string diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 765cd8483..177088ac5 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -1,13 +1,16 @@ package mock_server import ( + "context" "net" + "net/netip" "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/group" @@ -19,93 +22,95 @@ import ( ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) - CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, + GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) + CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) - GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) - GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) - ListUsersFunc func(accountID string) ([]*server.User, error) - GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) - DeletePeerFunc func(accountID, peerKey, userID string) error - GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) - GetPeerNetworkFunc func(peerKey string) (*server.Network, error) - AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) - GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) - GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) - GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) - SaveGroupFunc func(accountID, userID string, group *group.Group) error - DeleteGroupFunc func(accountID, userId, groupID string) error - ListGroupsFunc func(accountID string) ([]*group.Group, error) - GroupAddPeerFunc func(accountID, groupID, peerID string) error - GroupDeletePeerFunc func(accountID, groupID, peerID string) error - DeleteRuleFunc func(accountID, ruleID, userID string) error - GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(accountID, userID string, policy *server.Policy) error - DeletePolicyFunc func(accountID, policyID, userID string) error - ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) - GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) - MarkPATUsedFunc func(pat string) error - UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error - UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) - GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error) - SaveRouteFunc func(accountID string, userID string, route *route.Route) error - DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error - ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) - ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) - SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) - SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) - DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error - CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) - DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) - GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) - GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error - ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) - CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error - DeleteAccountFunc func(accountID, userID string) error + GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) + GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) + GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) + ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) + GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error + SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error + GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error) + GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error) + AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error + DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error + ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) + GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error + GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error + DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error + ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) + GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) + GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + MarkPATUsedFunc func(ctx context.Context, pat string) error + UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error + UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error + UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) + SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error + DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error + ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) + SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) + ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) + SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) + SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) + DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error + GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) + GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) + GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error + DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error + ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) + CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) + GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error + DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string - StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) - GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) - SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error - GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) - LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) - SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) - InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error + StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) + GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) + GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) + SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error + GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) + LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() server.ExternalCacheManager - GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error - DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error - ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) + GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error + ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager - UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error - GroupValidationFunc func(accountId string, groups []string) (bool, error) + UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error + GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error) + SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) } -func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncAndMarkPeerFunc != nil { - return am.SyncAndMarkPeerFunc(peerPubKey, realIP) + return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP) } - return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { +func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error { // TODO implement me panic("implement me") } @@ -119,43 +124,43 @@ func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[st } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*group.Group, error) { +func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*group.Group, error) { if am.GetGroupFunc != nil { - return am.GetGroupFunc(accountId, groupID, userID) + return am.GetGroupFunc(ctx, accountId, groupID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented") } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*group.Group, error) { if am.GetAllGroupsFunc != nil { - return am.GetAllGroupsFunc(accountID, userID) + return am.GetAllGroupsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetAllGroups is not implemented") } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface -func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) { +func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { - return am.GetUsersFromAccountFunc(accountID, userID) + return am.GetUsersFromAccountFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented") } // DeletePeer mock implementation of DeletePeer from server.AccountManager interface -func (am *MockAccountManager) DeletePeer(accountID, peerID, userID string) error { +func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { if am.DeletePeerFunc != nil { - return am.DeletePeerFunc(accountID, peerID, userID) + return am.DeletePeerFunc(ctx, accountID, peerID, userID) } return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") } // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface func (am *MockAccountManager) GetOrCreateAccountByUser( - userId, domain string, + ctx context.Context, userId, domain string, ) (*server.Account, error) { if am.GetOrCreateAccountByUserFunc != nil { - return am.GetOrCreateAccountByUserFunc(userId, domain) + return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) } return nil, status.Errorf( codes.Unimplemented, @@ -165,6 +170,7 @@ func (am *MockAccountManager) GetOrCreateAccountByUser( // CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface func (am *MockAccountManager) CreateSetupKey( + ctx context.Context, accountID string, keyName string, keyType server.SetupKeyType, @@ -175,17 +181,17 @@ func (am *MockAccountManager) CreateSetupKey( ephemeral bool, ) (*server.SetupKey, error) { if am.CreateSetupKeyFunc != nil { - return am.CreateSetupKeyFunc(accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral) + return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral) } return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } // GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface func (am *MockAccountManager) GetAccountByUserOrAccountID( - userId, accountId, domain string, + ctx context.Context, userId, accountId, domain string, ) (*server.Account, error) { if am.GetAccountByUserOrAccountIdFunc != nil { - return am.GetAccountByUserOrAccountIdFunc(userId, accountId, domain) + return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain) } return nil, status.Errorf( codes.Unimplemented, @@ -194,391 +200,392 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID( } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { if am.MarkPeerConnectedFunc != nil { - return am.MarkPeerConnectedFunc(peerKey, connected, realIP) + return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) } return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { if am.GetAccountFromPATFunc != nil { - return am.GetAccountFromPATFunc(pat) + return am.GetAccountFromPATFunc(ctx, pat) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") } // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface -func (am *MockAccountManager) DeleteAccount(accountID, userID string) error { +func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { if am.DeleteAccountFunc != nil { - return am.DeleteAccountFunc(accountID, userID) + return am.DeleteAccountFunc(ctx, accountID, userID) } return status.Errorf(codes.Unimplemented, "method DeleteAccount is not implemented") } // MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface -func (am *MockAccountManager) MarkPATUsed(pat string) error { +func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error { if am.MarkPATUsedFunc != nil { - return am.MarkPATUsedFunc(pat) + return am.MarkPATUsedFunc(ctx, pat) } return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented") } // CreatePAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { +func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if am.CreatePATFunc != nil { - return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn) + return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn) } return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented") } // DeletePAT mock implementation of DeletePAT from server.AccountManager interface -func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { +func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if am.DeletePATFunc != nil { - return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID) + return am.DeletePATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID) } return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented") } // GetPAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { if am.GetPATFunc != nil { - return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID) + return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID) } return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented") } // GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface -func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { if am.GetAllPATsFunc != nil { - return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID) + return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID) } return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented") } // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface -func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) { +func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) { if am.GetNetworkMapFunc != nil { - return am.GetNetworkMapFunc(peerKey) + return am.GetNetworkMapFunc(ctx, peerKey) } return nil, status.Errorf(codes.Unimplemented, "method GetNetworkMap is not implemented") } // GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface -func (am *MockAccountManager) GetPeerNetwork(peerKey string) (*server.Network, error) { +func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) { if am.GetPeerNetworkFunc != nil { - return am.GetPeerNetworkFunc(peerKey) + return am.GetPeerNetworkFunc(ctx, peerKey) } return nil, status.Errorf(codes.Unimplemented, "method GetPeerNetwork is not implemented") } // AddPeer mock implementation of AddPeer from server.AccountManager interface func (am *MockAccountManager) AddPeer( + ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer, -) (*nbpeer.Peer, *server.NetworkMap, error) { +) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { - return am.AddPeerFunc(setupKey, userId, peer) + return am.AddPeerFunc(ctx, setupKey, userId, peer) } - return nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*group.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*group.Group, error) { if am.GetGroupFunc != nil { - return am.GetGroupByNameFunc(accountID, groupName) + return am.GetGroupByNameFunc(ctx, accountID, groupName) } return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented") } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(accountID, userID string, group *group.Group) error { +func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *group.Group) error { if am.SaveGroupFunc != nil { - return am.SaveGroupFunc(accountID, userID, group) + return am.SaveGroupFunc(ctx, accountID, userID, group) } return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") } // DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface -func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error { +func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { if am.DeleteGroupFunc != nil { - return am.DeleteGroupFunc(accountId, userId, groupID) + return am.DeleteGroupFunc(ctx, accountId, userId, groupID) } return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented") } // ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(accountID string) ([]*group.Group, error) { +func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { if am.ListGroupsFunc != nil { - return am.ListGroupsFunc(accountID) + return am.ListGroupsFunc(ctx, accountID) } return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented") } // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface -func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerID string) error { +func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { - return am.GroupAddPeerFunc(accountID, groupID, peerID) + return am.GroupAddPeerFunc(ctx, accountID, groupID, peerID) } return status.Errorf(codes.Unimplemented, "method GroupAddPeer is not implemented") } // GroupDeletePeer mock implementation of GroupDeletePeer from server.AccountManager interface -func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { +func (am *MockAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupDeletePeerFunc != nil { - return am.GroupDeletePeerFunc(accountID, groupID, peerID) + return am.GroupDeletePeerFunc(ctx, accountID, groupID, peerID) } return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented") } // DeleteRule mock implementation of DeleteRule from server.AccountManager interface -func (am *MockAccountManager) DeleteRule(accountID, ruleID, userID string) error { +func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, userID string) error { if am.DeleteRuleFunc != nil { - return am.DeleteRuleFunc(accountID, ruleID, userID) + return am.DeleteRuleFunc(ctx, accountID, ruleID, userID) } return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented") } // GetPolicy mock implementation of GetPolicy from server.AccountManager interface -func (am *MockAccountManager) GetPolicy(accountID, policyID, userID string) (*server.Policy, error) { +func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) { if am.GetPolicyFunc != nil { - return am.GetPolicyFunc(accountID, policyID, userID) + return am.GetPolicyFunc(ctx, accountID, policyID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetPolicy is not implemented") } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(accountID, userID string, policy *server.Policy) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(accountID, userID, policy) + return am.SavePolicyFunc(ctx, accountID, userID, policy) } return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } // DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface -func (am *MockAccountManager) DeletePolicy(accountID, policyID, userID string) error { +func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { if am.DeletePolicyFunc != nil { - return am.DeletePolicyFunc(accountID, policyID, userID) + return am.DeletePolicyFunc(ctx, accountID, policyID, userID) } return status.Errorf(codes.Unimplemented, "method DeletePolicy is not implemented") } // ListPolicies mock implementation of ListPolicies from server.AccountManager interface -func (am *MockAccountManager) ListPolicies(accountID, userID string) ([]*server.Policy, error) { +func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) { if am.ListPoliciesFunc != nil { - return am.ListPoliciesFunc(accountID, userID) + return am.ListPoliciesFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListPolicies is not implemented") } // UpdatePeerMeta mock implementation of UpdatePeerMeta from server.AccountManager interface -func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta nbpeer.PeerSystemMeta) error { +func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error { if am.UpdatePeerMetaFunc != nil { - return am.UpdatePeerMetaFunc(peerID, meta) + return am.UpdatePeerMetaFunc(ctx, peerID, meta) } return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented") } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) { +func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) { if am.GetUserFunc != nil { - return am.GetUserFunc(claims) + return am.GetUserFunc(ctx, claims) } return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented") } -func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) { +func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) { if am.ListUsersFunc != nil { - return am.ListUsersFunc(accountID) + return am.ListUsersFunc(ctx, accountID) } return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented") } // UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager -func (am *MockAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error { +func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { if am.UpdatePeerSSHKeyFunc != nil { - return am.UpdatePeerSSHKeyFunc(peerID, sshKey) + return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey) } return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented") } // UpdatePeer mocks UpdatePeerFunc function of the account manager -func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { +func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { if am.UpdatePeerFunc != nil { - return am.UpdatePeerFunc(accountID, userID, peer) + return am.UpdatePeerFunc(ctx, accountID, userID, peer) } return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented") } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(accountID, prefix, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } // GetRoute mock implementation of GetRoute from server.AccountManager interface -func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) { +func (am *MockAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { if am.GetRouteFunc != nil { - return am.GetRouteFunc(accountID, routeID, userID) + return am.GetRouteFunc(ctx, accountID, routeID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented") } // SaveRoute mock implementation of SaveRoute from server.AccountManager interface -func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error { +func (am *MockAccountManager) SaveRoute(ctx context.Context, accountID string, userID string, route *route.Route) error { if am.SaveRouteFunc != nil { - return am.SaveRouteFunc(accountID, userID, route) + return am.SaveRouteFunc(ctx, accountID, userID, route) } return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented") } // DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface -func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error { +func (am *MockAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { if am.DeleteRouteFunc != nil { - return am.DeleteRouteFunc(accountID, routeID, userID) + return am.DeleteRouteFunc(ctx, accountID, routeID, userID) } return status.Errorf(codes.Unimplemented, "method DeleteRoute is not implemented") } // ListRoutes mock implementation of ListRoutes from server.AccountManager interface -func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { +func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { if am.ListRoutesFunc != nil { - return am.ListRoutesFunc(accountID, userID) + return am.ListRoutesFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented") } // SaveSetupKey mocks SaveSetupKey of the AccountManager interface -func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) { +func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) { if am.SaveSetupKeyFunc != nil { - return am.SaveSetupKeyFunc(accountID, key, userID) + return am.SaveSetupKeyFunc(ctx, accountID, key, userID) } return nil, status.Errorf(codes.Unimplemented, "method SaveSetupKey is not implemented") } // GetSetupKey mocks GetSetupKey of the AccountManager interface -func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) { +func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { if am.GetSetupKeyFunc != nil { - return am.GetSetupKeyFunc(accountID, userID, keyID) + return am.GetSetupKeyFunc(ctx, accountID, userID, keyID) } return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented") } // ListSetupKeys mocks ListSetupKeys of the AccountManager interface -func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) { +func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) { if am.ListSetupKeysFunc != nil { - return am.ListSetupKeysFunc(accountID, userID) + return am.ListSetupKeysFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented") } // SaveUser mocks SaveUser of the AccountManager interface -func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.User) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) { if am.SaveUserFunc != nil { - return am.SaveUserFunc(accountID, userID, user) + return am.SaveUserFunc(ctx, accountID, userID, user) } return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented") } // SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { if am.SaveOrAddUserFunc != nil { - return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists) + return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists) } return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented") } // DeleteUser mocks DeleteUser of the AccountManager interface -func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error { +func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { if am.DeleteUserFunc != nil { - return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID) + return am.DeleteUserFunc(ctx, accountID, initiatorUserID, targetUserID) } return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented") } -func (am *MockAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error { +func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { if am.InviteUserFunc != nil { - return am.InviteUserFunc(accountID, initiatorUserID, targetUserID) + return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID) } return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented") } // GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface -func (am *MockAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { +func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if am.GetNameServerGroupFunc != nil { - return am.GetNameServerGroupFunc(accountID, userID, nsGroupID) + return am.GetNameServerGroupFunc(ctx, accountID, userID, nsGroupID) } return nil, nil } // CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface -func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) { +func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) { if am.CreateNameServerGroupFunc != nil { - return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled) + return am.CreateNameServerGroupFunc(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled) } return nil, nil } // SaveNameServerGroup mocks SaveNameServerGroup of the AccountManager interface -func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { +func (am *MockAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { if am.SaveNameServerGroupFunc != nil { - return am.SaveNameServerGroupFunc(accountID, userID, nsGroupToSave) + return am.SaveNameServerGroupFunc(ctx, accountID, userID, nsGroupToSave) } return nil } // DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface -func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { +func (am *MockAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { if am.DeleteNameServerGroupFunc != nil { - return am.DeleteNameServerGroupFunc(accountID, nsGroupID, userID) + return am.DeleteNameServerGroupFunc(ctx, accountID, nsGroupID, userID) } return nil } // ListNameServerGroups mocks ListNameServerGroups of the AccountManager interface -func (am *MockAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) { +func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { if am.ListNameServerGroupsFunc != nil { - return am.ListNameServerGroupsFunc(accountID, userID) + return am.ListNameServerGroupsFunc(ctx, accountID, userID) } return nil, nil } // CreateUser mocks CreateUser of the AccountManager interface -func (am *MockAccountManager) CreateUser(accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) { +func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) { if am.CreateUserFunc != nil { - return am.CreateUserFunc(accountID, userID, invite) + return am.CreateUserFunc(ctx, accountID, userID, invite) } return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } // GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface -func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, +func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error, ) { if am.GetAccountFromTokenFunc != nil { - return am.GetAccountFromTokenFunc(claims) + return am.GetAccountFromTokenFunc(ctx, claims) } return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") } -func (am *MockAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { +func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { if am.CheckUserAccessByJWTGroupsFunc != nil { - return am.CheckUserAccessByJWTGroupsFunc(claims) + return am.CheckUserAccessByJWTGroupsFunc(ctx, claims) } return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented") } // GetPeers mocks GetPeers of the AccountManager interface -func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) { +func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { if am.GetPeersFunc != nil { - return am.GetPeersFunc(accountID, userID) + return am.GetPeersFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented") } @@ -592,59 +599,59 @@ func (am *MockAccountManager) GetDNSDomain() string { } // GetEvents mocks GetEvents of the AccountManager interface -func (am *MockAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) { +func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { if am.GetEventsFunc != nil { - return am.GetEventsFunc(accountID, userID) + return am.GetEventsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetEvents is not implemented") } // GetDNSSettings mocks GetDNSSettings of the AccountManager interface -func (am *MockAccountManager) GetDNSSettings(accountID string, userID string) (*server.DNSSettings, error) { +func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { if am.GetDNSSettingsFunc != nil { - return am.GetDNSSettingsFunc(accountID, userID) + return am.GetDNSSettingsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetDNSSettings is not implemented") } // SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface -func (am *MockAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { +func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { if am.SaveDNSSettingsFunc != nil { - return am.SaveDNSSettingsFunc(accountID, userID, dnsSettingsToSave) + return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave) } return status.Errorf(codes.Unimplemented, "method SaveDNSSettings is not implemented") } // GetPeer mocks GetPeer of the AccountManager interface -func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) { +func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { if am.GetPeerFunc != nil { - return am.GetPeerFunc(accountID, peerID, userID) + return am.GetPeerFunc(ctx, accountID, peerID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented") } // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface -func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) { +func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { if am.UpdateAccountSettingsFunc != nil { - return am.UpdateAccountSettingsFunc(accountID, userID, newSettings) + return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) } return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountSettings is not implemented") } // LoginPeer mocks LoginPeer of the AccountManager interface -func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) { +func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.LoginPeerFunc != nil { - return am.LoginPeerFunc(login) + return am.LoginPeerFunc(ctx, login) } - return nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented") } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { - return am.SyncPeerFunc(sync, account) + return am.SyncPeerFunc(ctx, sync, account) } - return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } // GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface @@ -664,9 +671,9 @@ func (am *MockAccountManager) HasConnectedChannel(peerID string) bool { } // StoreEvent mocks StoreEvent of the AccountManager interface -func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { +func (am *MockAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { if am.StoreEventFunc != nil { - am.StoreEventFunc(initiatorID, targetID, accountID, activityID, meta) + am.StoreEventFunc(ctx, initiatorID, targetID, accountID, activityID, meta) } } @@ -679,35 +686,35 @@ func (am *MockAccountManager) GetExternalCacheManager() server.ExternalCacheMana } // GetPostureChecks mocks GetPostureChecks of the AccountManager interface -func (am *MockAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) { +func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { if am.GetPostureChecksFunc != nil { - return am.GetPostureChecksFunc(accountID, postureChecksID, userID) + return am.GetPostureChecksFunc(ctx, accountID, postureChecksID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetPostureChecks is not implemented") } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { if am.SavePostureChecksFunc != nil { - return am.SavePostureChecksFunc(accountID, userID, postureChecks) + return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) } return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface -func (am *MockAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error { +func (am *MockAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { if am.DeletePostureChecksFunc != nil { - return am.DeletePostureChecksFunc(accountID, postureChecksID, userID) + return am.DeletePostureChecksFunc(ctx, accountID, postureChecksID, userID) } return status.Errorf(codes.Unimplemented, "method DeletePostureChecks is not implemented") } // ListPostureChecks mocks ListPostureChecks of the AccountManager interface -func (am *MockAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) { +func (am *MockAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { if am.ListPostureChecksFunc != nil { - return am.ListPostureChecksFunc(accountID, userID) + return am.ListPostureChecksFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListPostureChecks is not implemented") } @@ -721,21 +728,29 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager { } // UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface -func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { +func (am *MockAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error { if am.UpdateIntegratedValidatorGroupsFunc != nil { - return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups) + return am.UpdateIntegratedValidatorGroupsFunc(ctx, accountID, userID, groups) } return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented") } // GroupValidation mocks GroupValidation of the AccountManager interface -func (am *MockAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { +func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { if am.GroupValidationFunc != nil { - return am.GroupValidationFunc(accountId, groups) + return am.GroupValidationFunc(ctx, accountId, groups) } return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented") } +// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface +func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error { + if am.SyncPeerMetaFunc != nil { + return am.SyncPeerMetaFunc(ctx, peerPubKey, meta) + } + return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented") +} + // FindExistingPostureCheck mocks FindExistingPostureCheck of the AccountManager interface func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { if am.FindExistingPostureCheckFunc != nil { @@ -743,3 +758,11 @@ func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks } return nil, status.Errorf(codes.Unimplemented, "method FindExistingPostureCheck is not implemented") } + +// GetAccountIDForPeerKey mocks GetAccountIDForPeerKey of the AccountManager interface +func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) { + if am.GetAccountIDForPeerKeyFunc != nil { + return am.GetAccountIDForPeerKeyFunc(ctx, peerKey) + } + return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented") +} diff --git a/management/server/mock_server/management_server_mock.go b/management/server/mock_server/management_server_mock.go index 29544b53f..d79fbd4e9 100644 --- a/management/server/mock_server/management_server_mock.go +++ b/management/server/mock_server/management_server_mock.go @@ -3,9 +3,10 @@ package mock_server import ( "context" - "github.com/netbirdio/netbird/management/proto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/management/proto" ) type ManagementServiceServerMock struct { @@ -17,6 +18,7 @@ type ManagementServiceServerMock struct { IsHealthyFunc func(context.Context, *proto.Empty) (*proto.Empty, error) GetDeviceAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) GetPKCEAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) + SyncMetaFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) } func (m ManagementServiceServerMock) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { @@ -60,3 +62,10 @@ func (m ManagementServiceServerMock) GetPKCEAuthorizationFlow(ctx context.Contex } return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented") } + +func (m ManagementServiceServerMock) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { + if m.SyncMetaFunc != nil { + return m.SyncMetaFunc(ctx, req) + } + return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 44d231c3e..f8d644ded 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -1,6 +1,7 @@ package server import ( + "context" "errors" "regexp" "unicode/utf8" @@ -17,12 +18,12 @@ import ( const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs -func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { +func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -45,12 +46,12 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID } // CreateNameServerGroup creates and saves a new nameserver group -func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { +func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -79,29 +80,29 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d account.NameServerGroups[newNSGroup.ID] = newNSGroup account.Network.IncSerial() - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) + am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) return newNSGroup.Copy(), nil } // SaveNameServerGroup saves nameserver group -func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { +func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if nsGroupToSave == nil { return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -114,25 +115,25 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave account.Network.IncSerial() - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) + am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil } // DeleteNameServerGroup deletes nameserver group with nsGroupID -func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { +func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -144,25 +145,25 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use delete(account.NameServerGroups, nsGroupID) account.Network.IncSerial() - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) + am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) return nil } // ListNameServerGroups returns a list of nameserver groups from account -func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) { +func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index f2921532d..dd7935fee 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "net/netip" "testing" @@ -383,6 +384,7 @@ func TestCreateNameServerGroup(t *testing.T) { } outNSGroup, err := am.CreateNameServerGroup( + context.Background(), account.Id, testCase.inputArgs.name, testCase.inputArgs.description, @@ -611,7 +613,7 @@ func TestSaveNameServerGroup(t *testing.T) { account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("account should be saved") } @@ -646,7 +648,7 @@ func TestSaveNameServerGroup(t *testing.T) { } } - err = am.SaveNameServerGroup(account.Id, userID, nsGroupToSave) + err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave) testCase.errFunc(t, err) @@ -654,7 +656,7 @@ func TestSaveNameServerGroup(t *testing.T) { return } - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) } @@ -705,17 +707,17 @@ func TestDeleteNameServerGroup(t *testing.T) { account.NameServerGroups[testingNSGroup.ID] = testingNSGroup - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("failed to save account") } - err = am.DeleteNameServerGroup(account.Id, testingNSGroup.ID, userID) + err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID) if err != nil { t.Error("deleting nameserver group failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(account.Id) + savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Error("failed to retrieve saved account with error: ", err) } @@ -738,7 +740,7 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("failed to init testing account") } - foundGroup, err := am.GetNameServerGroup(account.Id, testUserID, existingNSGroupID) + foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID) if err != nil { t.Error("getting existing nameserver group failed with error: ", err) } @@ -747,7 +749,7 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("got a nil group while getting nameserver group with ID") } - _, err = am.GetNameServerGroup(account.Id, testUserID, "not existing") + _, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing") if err == nil { t.Error("getting not existing nameserver group should return error, got nil") } @@ -760,13 +762,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(dataDir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) if err != nil { return nil, err } @@ -829,7 +831,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error userID := testUserID domain := "example.com" - account := newAccountWithId(accountID, userID, domain) + account := newAccountWithId(context.Background(), accountID, userID, domain) account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup @@ -846,16 +848,16 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.Groups[newGroup1.ID] = newGroup1 account.Groups[newGroup2.ID] = newGroup2 - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - _, _, err = am.AddPeer("", userID, peer1) + _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) if err != nil { return nil, err } - _, _, err = am.AddPeer("", userID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) if err != nil { return nil, err } diff --git a/management/server/peer.go b/management/server/peer.go index e6488aa3a..9b48276ce 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1,11 +1,13 @@ package server import ( + "context" "fmt" "net" "strings" "time" + "github.com/netbirdio/netbird/management/server/posture" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -19,6 +21,11 @@ import ( type PeerSync struct { // WireGuardPubKey is a peers WireGuard public key WireGuardPubKey string + // Meta is the system information passed by peer, must be always present + Meta nbpeer.PeerSystemMeta + // UpdateAccountPeers indicate updating account peers, + // which occurs when the peer's metadata is updated + UpdateAccountPeers bool } // PeerLogin used as a data object between the gRPC API and AccountManager on Login request. @@ -39,8 +46,8 @@ type PeerLogin struct { // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. -func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) { - account, err := am.Store.GetAccount(accountID) +func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -73,7 +80,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(peer.ID, approvedPeersMap) + aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -88,7 +95,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error { +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return err @@ -107,7 +114,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected if am.geo != nil && realIP != nil { location, err := am.geo.Lookup(realIP) if err != nil { - log.Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) + log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) } else { peer.Location.ConnectionIP = realIP peer.Location.CountryCode = location.Country.ISOCode @@ -115,7 +122,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected peer.Location.GeoNameID = location.City.GeonameID err = am.Store.SavePeerLocation(account.Id, peer) if err != nil { - log.Warnf("could not store location for peer %s: %s", peer.ID, err) + log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) } } } @@ -128,24 +135,24 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected } if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(account) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } if oldStatus.LoginExpired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } return nil } // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. -func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -155,7 +162,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) } - update, err = am.integratedPeerValidator.ValidatePeer(update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, err } @@ -166,7 +173,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb if !update.SSHEnabled { event = activity.PeerSSHDisabled } - am.StoreEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } if peer.Name != update.Name { @@ -181,7 +188,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb peer.DNSLabel = newLabel - am.StoreEvent(userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) } if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { @@ -196,27 +203,27 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb if !update.LoginExpirationEnabled { event = activity.PeerLoginExpirationDisabled } - am.StoreEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(account) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } } account.UpdatePeer(peer) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return peer, nil } // deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock -func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, userID string) error { +func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error { // the first loop is needed to ensure all peers present under the account before modifying, otherwise // we might have some inconsistencies @@ -233,13 +240,13 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, // the 2nd loop performs the actual modification for _, peer := range peers { - err := am.integratedPeerValidator.PeerDeleted(account.Id, peer.ID) + err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID) if err != nil { return err } account.DeletePeer(peer.ID) - am.peersUpdateManager.SendUpdate(peer.ID, + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{ Update: &proto.SyncResponse{ // fill those field for backward compatibility @@ -255,41 +262,41 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, }, }, }) - am.peersUpdateManager.CloseChannel(peer.ID) - am.StoreEvent(userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + am.peersUpdateManager.CloseChannel(ctx, peer.ID) + am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) } return nil } // DeletePeer removes peer from the account by its IP -func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - err = am.deletePeers(account, []string{peerID}, userID) + err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { return err } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) -func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, error) { - account, err := am.Store.GetAccountByPeerID(peerID) +func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) { + account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err } @@ -308,12 +315,12 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro if err != nil { return nil, err } - return account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validatedPeers), nil + return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil } // GetPeerNetwork returns the Network for a given peer -func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error) { - account, err := am.Store.GetAccountByPeerID(peerID) +func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) { + account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err } @@ -328,10 +335,10 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error) // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access - return nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") + return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") } upperKey := strings.ToUpper(setupKey) @@ -342,13 +349,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P addedByUser = true accountID, err = am.Store.GetAccountIDByUserID(userID) } else { - accountID, err = am.Store.GetAccountIDBySetupKey(setupKey) + accountID, err = am.Store.GetAccountIDBySetupKey(ctx, setupKey) } if err != nil { - return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") + return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer func() { if unlock != nil { unlock() @@ -357,14 +364,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P var account *Account // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(accountID) + account, err = am.Store.GetAccount(ctx, accountID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { if am.idpManager != nil { - userdata, err := am.lookupUserInCache(userID, account) + userdata, err := am.lookupUserInCache(ctx, userID, account) if err == nil && userdata != nil { peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) } @@ -378,7 +385,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P // The connecting peer should be able to recover with a retry. _, err = account.FindPeerByPubKey(peer.Key) if err == nil { - return nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") + return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } opEvent := &activity.Event{ @@ -392,11 +399,11 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P // validate the setup key if adding with a key sk, err := account.FindSetupKey(upperKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if !sk.IsValid() { - return nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") + return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") } account.SetupKeys[sk.Key] = sk.IncrementUsage() @@ -414,14 +421,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels) if err != nil { - return nil, nil, err + return nil, nil, nil, err } peer.DNSLabel = newLabel network := account.Network nextIp, err := AllocatePeerIP(network.Net, takenIps) if err != nil { - return nil, nil, err + return nil, nil, nil, err } registrationTime := time.Now().UTC() @@ -448,7 +455,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P // add peer to 'All' group group, err := account.GetGroupAll() if err != nil { - return nil, nil, err + return nil, nil, nil, err } group.Peers = append(group.Peers, newPeer.ID) @@ -456,12 +463,12 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P if addedByUser { groupsToAdd, err = account.getUserGroups(userID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } } else { groupsToAdd, err = account.getSetupKeyGroups(upperKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } } @@ -473,21 +480,21 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } } - newPeer = am.integratedPeerValidator.PreparePeer(account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) + newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) if addedByUser { user, err := account.FindUser(userID) if err != nil { - return nil, nil, status.Errorf(status.Internal, "couldn't find user") + return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") } user.updateLastLogin(newPeer.LastLogin) } account.Peers[newPeer.ID] = newPeer account.Network.IncSerial() - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // Account is saved, we can release the lock @@ -500,61 +507,79 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P opEvent.Meta["setup_key_name"] = setupKeyName } - am.StoreEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap) - return newPeer, networkMap, nil + + postureChecks := am.getPeerPostureChecks(account, peer) + networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap) + return newPeer, networkMap, postureChecks, nil } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) if err != nil { - return nil, nil, status.NewPeerNotRegisteredError() + return nil, nil, nil, status.NewPeerNotRegisteredError() } err = checkIfPeerOwnerIsBlocked(peer, account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - if peerLoginExpired(peer, account.Settings) { - return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") + if peerLoginExpired(ctx, peer, account.Settings) { + return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } - peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + peer, updated := updatePeerMeta(peer, sync.Meta, account) + if updated { + err = am.Store.SaveAccount(ctx, account) + if err != nil { + return nil, nil, nil, err + } + + if sync.UpdateAccountPeers { + am.updateAccountPeers(ctx, account) + } + } + + peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { - return nil, nil, err + return nil, nil, nil, err } + var postureChecks []*posture.Checks + if peerNotValid { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } - return peer, emptyMap, nil + return peer, emptyMap, postureChecks, nil } if isStatusChanged { - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } validPeersMap, err := am.GetValidatedPeers(account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil + postureChecks = am.getPeerPostureChecks(account, peer) + + return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil } // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. -func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) { - accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey) +func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. @@ -567,7 +592,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw if am.geo != nil && login.ConnectionIP != nil { location, err := am.geo.Lookup(login.ConnectionIP) if err != nil { - log.Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err) + log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err) } else { newPeer.Location.ConnectionIP = login.ConnectionIP newPeer.Location.CountryCode = location.Country.ISOCode @@ -577,49 +602,50 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } } - return am.AddPeer(login.SetupKey, login.UserID, newPeer) + return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) } - log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) - return nil, nil, status.Errorf(status.Internal, "failed while logging in peer") + + log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) + return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") } - peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { - return nil, nil, status.NewPeerNotRegisteredError() + return nil, nil, nil, status.NewPeerNotRegisteredError() } - accSettings, err := am.Store.GetAccountSettings(accountID) + accSettings, err := am.Store.GetAccountSettings(ctx, accountID) if err != nil { - return nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err) + return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err) } var isWriteLock bool // duplicated logic from after the lock to have an early exit - expired := peerLoginExpired(peer, accSettings) + expired := peerLoginExpired(ctx, peer, accSettings) switch { case expired: - if err := checkAuth(login.UserID, peer); err != nil { - return nil, nil, err + if err := checkAuth(ctx, login.UserID, peer); err != nil { + return nil, nil, nil, err } isWriteLock = true - log.Debugf("peer login expired, acquiring write lock") + log.WithContext(ctx).Debugf("peer login expired, acquiring write lock") case peer.UpdateMetaIfNew(login.Meta): isWriteLock = true - log.Debugf("peer changed meta, acquiring write lock") + log.WithContext(ctx).Debugf("peer changed meta, acquiring write lock") default: isWriteLock = false - log.Debugf("peer meta is the same, acquiring read lock") + log.WithContext(ctx).Debugf("peer meta is the same, acquiring read lock") } var unlock func() if isWriteLock { - unlock = am.Store.AcquireAccountWriteLock(accountID) + unlock = am.Store.AcquireAccountWriteLock(ctx, accountID) } else { - unlock = am.Store.AcquireAccountReadLock(accountID) + unlock = am.Store.AcquireAccountReadLock(ctx, accountID) } defer func() { if unlock != nil { @@ -628,28 +654,28 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw }() // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } peer, err = account.FindPeerByPubKey(login.WireGuardPubKey) if err != nil { - return nil, nil, status.NewPeerNotRegisteredError() + return nil, nil, nil, status.NewPeerNotRegisteredError() } err = checkIfPeerOwnerIsBlocked(peer, account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // this flag prevents unnecessary calls to the persistent store. shouldStoreAccount := false updateRemotePeers := false - if peerLoginExpired(peer, account.Settings) { - err = checkAuth(login.UserID, peer) + if peerLoginExpired(ctx, peer, account.Settings) { + err = checkAuth(ctx, login.UserID, peer) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. @@ -660,57 +686,60 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw // sync user last login with peer last login user, err := account.FindUser(login.UserID) if err != nil { - return nil, nil, status.Errorf(status.Internal, "couldn't find user") + return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") } user.updateLastLogin(peer.LastLogin) - am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) } - isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { - return nil, nil, err + return nil, nil, nil, err } peer, updated := updatePeerMeta(peer, login.Meta, account) if updated { shouldStoreAccount = true } - peer, err = am.checkAndUpdatePeerSSHKey(peer, account, login.SSHKey) + peer, err = am.checkAndUpdatePeerSSHKey(ctx, peer, account, login.SSHKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if shouldStoreAccount { if !isWriteLock { - log.Errorf("account %s should be stored but is not write locked", accountID) - return nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked") + log.WithContext(ctx).Errorf("account %s should be stored but is not write locked", accountID) + return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked") } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } } unlock() unlock = nil if updateRemotePeers || isStatusChanged { - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } + var postureChecks []*posture.Checks + if isRequiresApproval { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } - return peer, emptyMap, nil + return peer, emptyMap, postureChecks, nil } approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { - return nil, nil, err + return nil, nil, nil, err } + postureChecks = am.getPeerPostureChecks(account, peer) - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil + return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil } func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { @@ -726,23 +755,23 @@ func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { return nil } -func checkAuth(loginUserID string, peer *nbpeer.Peer) error { +func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error { if loginUserID == "" { // absence of a user ID indicates that JWT wasn't provided. return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } if peer.UserID != loginUserID { - log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) + log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) return status.Errorf(status.Unauthenticated, "can't login") } return nil } -func peerLoginExpired(peer *nbpeer.Peer, settings *Settings) bool { +func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool { expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration) expired = settings.PeerLoginExpirationEnabled && expired if expired || peer.Status.LoginExpired { - log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn) + log.WithContext(ctx).Debugf("peer's %s login expired %v ago", peer.ID, expiresIn) return true } return false @@ -753,48 +782,48 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) { account.UpdatePeer(peer) } -func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) { +func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(ctx context.Context, peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) { if len(newSSHKey) == 0 { - log.Debugf("no new SSH key provided for peer %s, skipping update", peer.ID) + log.WithContext(ctx).Debugf("no new SSH key provided for peer %s, skipping update", peer.ID) return peer, nil } if peer.SSHKey == newSSHKey { - log.Debugf("same SSH key provided for peer %s, skipping update", peer.ID) + log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peer.ID) return peer, nil } peer.SSHKey = newSSHKey account.UpdatePeer(peer) - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } // trigger network map update - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return peer, nil } // UpdatePeerSSHKey updates peer's public SSH key -func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error { +func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { if sshKey == "" { - log.Debugf("empty SSH key provided for peer %s, skipping update", peerID) + log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID) return nil } - account, err := am.Store.GetAccountByPeerID(peerID) + account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return err } - unlock := am.Store.AcquireAccountWriteLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) defer unlock() // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(ctx, account.Id) if err != nil { return err } @@ -805,30 +834,30 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) } if peer.SSHKey == sshKey { - log.Debugf("same SSH key provided for peer %s, skipping update", peerID) + log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID) return nil } peer.SSHKey = sshKey account.UpdatePeer(peer) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } // trigger network map update - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // GetPeer for a given accountID, peerID and userID error if not found. -func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -865,7 +894,7 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp } for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(p.ID, approvedPeersMap) + aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -886,21 +915,23 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(account *Account) { +func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { - log.Errorf("failed send out updates to peers, failed to validate peer: %v", err) + log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err) return } for _, peer := range peers { if !am.peersUpdateManager.HasChannel(peer.ID) { - log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) continue } - remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) - update := toSyncResponse(nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain()) - am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) + + postureChecks := am.getPeerPostureChecks(account, peer) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap) + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks) + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) } } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index f71f629f6..4f808a79e 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "net/netip" + "slices" "time" ) @@ -79,6 +80,13 @@ type Environment struct { Platform string } +// File is a file on the system. +type File struct { + Path string + Exist bool + ProcessIsRunning bool +} + // PeerSystemMeta is a metadata of a Peer machine system type PeerSystemMeta struct { //nolint:revive Hostname string @@ -96,24 +104,22 @@ type PeerSystemMeta struct { //nolint:revive SystemProductName string SystemManufacturer string Environment Environment `gorm:"serializer:json"` + Files []File `gorm:"serializer:json"` } func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { - if len(p.NetworkAddresses) != len(other.NetworkAddresses) { + equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool { + return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP + }) + if !equalNetworkAddresses { return false } - for _, addr := range p.NetworkAddresses { - var found bool - for _, oAddr := range other.NetworkAddresses { - if addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP { - found = true - continue - } - } - if !found { - return false - } + equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool { + return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning + }) + if !equalFiles { + return false } return p.Hostname == other.Hostname && @@ -133,6 +139,26 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { p.Environment.Platform == other.Environment.Platform } +func (p PeerSystemMeta) isEmpty() bool { + return p.Hostname == "" && + p.GoOS == "" && + p.Kernel == "" && + p.Core == "" && + p.Platform == "" && + p.OS == "" && + p.OSVersion == "" && + p.WtVersion == "" && + p.UIVersion == "" && + p.KernelVersion == "" && + len(p.NetworkAddresses) == 0 && + p.SystemSerialNumber == "" && + p.SystemProductName == "" && + p.SystemManufacturer == "" && + p.Environment.Cloud == "" && + p.Environment.Platform == "" && + len(p.Files) == 0 +} + // AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user. func (p *Peer) AddedWithSSOLogin() bool { return p.UserID != "" @@ -168,6 +194,10 @@ func (p *Peer) Copy() *Peer { // UpdateMetaIfNew updates peer's system metadata if new information is provided // returns true if meta was updated, false otherwise func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool { + if meta.isEmpty() { + return false + } + // Avoid overwriting UIVersion if the update was triggered sole by the CLI client if meta.UIVersion == "" { meta.UIVersion = p.Meta.UIVersion diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 6063cc2a7..407877296 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "testing" "time" @@ -80,7 +81,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -92,7 +93,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { return } - peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -106,7 +107,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) return } - _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -116,7 +117,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { return } - networkMap, err := manager.GetNetworkMap(peer1.ID) + networkMap, err := manager.GetNetworkMap(context.Background(), peer1.ID) if err != nil { t.Fatal(err) return @@ -165,7 +166,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -179,7 +180,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) return } - peer2, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -188,13 +189,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policies, err := manager.ListPolicies(account.Id, userID) + policies, err := manager.ListPolicies(context.Background(), account.Id, userID) if err != nil { t.Errorf("expecting to get a list of rules, got failure %v", err) return } - err = manager.DeletePolicy(account.Id, policies[0].ID, userID) + err = manager.DeletePolicy(context.Background(), account.Id, policies[0].ID, userID) if err != nil { t.Errorf("expecting to delete 1 group, got failure %v", err) return @@ -213,12 +214,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) - err = manager.SaveGroup(account.Id, userID, &group1) + err = manager.SaveGroup(context.Background(), account.Id, userID, &group1) if err != nil { t.Errorf("expecting group1 to be added, got failure %v", err) return } - err = manager.SaveGroup(account.Id, userID, &group2) + err = manager.SaveGroup(context.Background(), account.Id, userID, &group2) if err != nil { t.Errorf("expecting group2 to be added, got failure %v", err) return @@ -235,13 +236,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { Action: PolicyTrafficActionAccept, }, } - err = manager.SavePolicy(account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return } - networkMap1, err := manager.GetNetworkMap(peer1.ID) + networkMap1, err := manager.GetNetworkMap(context.Background(), peer1.ID) if err != nil { t.Fatal(err) return @@ -264,7 +265,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { ) } - networkMap2, err := manager.GetNetworkMap(peer2.ID) + networkMap2, err := manager.GetNetworkMap(context.Background(), peer2.ID) if err != nil { t.Fatal(err) return @@ -283,13 +284,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return } - networkMap1, err = manager.GetNetworkMap(peer1.ID) + networkMap1, err = manager.GetNetworkMap(context.Background(), peer1.ID) if err != nil { t.Fatal(err) return @@ -304,7 +305,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - networkMap2, err = manager.GetNetworkMap(peer2.ID) + networkMap2, err = manager.GetNetworkMap(context.Background(), peer2.ID) if err != nil { t.Fatal(err) return @@ -329,7 +330,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -341,7 +342,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -355,7 +356,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) return } - _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -365,7 +366,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - network, err := manager.GetPeerNetwork(peer1.ID) + network, err := manager.GetPeerNetwork(context.Background(), peer1.ID) if err != nil { t.Fatal(err) return @@ -387,21 +388,21 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(accountID, adminUser, "") + account := newAccountWithId(context.Background(), accountID, adminUser, "") account.Users[someUser] = &User{ Id: someUser, Role: UserRoleUser, } account.Settings.RegularUsersViewBlocked = false - err = manager.Store.SaveAccount(account) + err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) return } // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { t.Fatal("error creating setup key") return @@ -413,7 +414,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { return } - peer1, _, err := manager.AddPeer("", someUser, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -429,7 +430,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // the second peer added with a setup key - peer2, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -439,7 +440,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // the user can see its own peer - peer, err := manager.GetPeer(accountID, peer1.ID, someUser) + peer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, someUser) if err != nil { t.Fatal(err) return @@ -447,7 +448,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { assert.NotNil(t, peer) // the user can see peer2 because peer1 of the user has access to peer2 due to the All group and the default rule 0 all-to-all access - peer, err = manager.GetPeer(accountID, peer2.ID, someUser) + peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) if err != nil { t.Fatal(err) return @@ -456,7 +457,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { // delete the all-to-all policy so that user's peer1 has no access to peer2 for _, policy := range account.Policies { - err = manager.DeletePolicy(accountID, policy.ID, adminUser) + err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) if err != nil { t.Fatal(err) return @@ -464,18 +465,18 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // at this point the user can't see the details of peer2 - peer, err = manager.GetPeer(accountID, peer2.ID, someUser) //nolint + peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) //nolint assert.Error(t, err) // admin users can always access all the peers - peer, err = manager.GetPeer(accountID, peer1.ID, adminUser) + peer, err = manager.GetPeer(context.Background(), accountID, peer1.ID, adminUser) if err != nil { t.Fatal(err) return } assert.NotNil(t, peer) - peer, err = manager.GetPeer(accountID, peer2.ID, adminUser) + peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, adminUser) if err != nil { t.Fatal(err) return @@ -574,7 +575,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(accountID, adminUser, "") + account := newAccountWithId(context.Background(), accountID, adminUser, "") account.Users[someUser] = &User{ Id: someUser, Role: testCase.role, @@ -583,7 +584,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { account.Policies = []*Policy{} account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings - err = manager.Store.SaveAccount(account) + err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) return @@ -601,7 +602,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -610,7 +611,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", adminUser, &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -619,7 +620,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - peers, err := manager.GetPeers(accountID, someUser) + peers, err := manager.GetPeers(context.Background(), accountID, someUser) if err != nil { t.Fatal(err) return diff --git a/management/server/policy.go b/management/server/policy.go index 5206df9e9..a70d7f0ed 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -1,6 +1,7 @@ package server import ( + "context" _ "embed" "strconv" "strings" @@ -211,9 +212,9 @@ type FirewallRule struct { // getPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { - generateResources, getAccumulatedResources := a.connResourcesGenerator() + generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) for _, policy := range a.Policies { if !policy.Enabled { continue @@ -224,8 +225,8 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma continue } - sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap) + sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap) if rule.Bidirectional { if peerInSources { @@ -254,7 +255,7 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // It safe to call the generator function multiple times for same peer and different rules no duplicates will be // generated. The accumulator function returns the result of all the generator calls. -func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { +func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { rulesExists := make(map[string]struct{}) peersExists := make(map[string]struct{}) rules := make([]*FirewallRule, 0) @@ -262,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in all, err := a.GetGroupAll() if err != nil { - log.Errorf("failed to get group all: %v", err) + log.WithContext(ctx).Errorf("failed to get group all: %v", err) all = &nbgroup.Group{} } @@ -313,11 +314,11 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in } // GetPolicy from the store -func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -341,11 +342,11 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) ( } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -353,7 +354,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po exists := am.savePolicy(account, policy) account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } @@ -361,19 +362,19 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po if exists { action = activity.PolicyUpdated } - am.StoreEvent(userID, policy.ID, accountID, action, policy.EventMeta()) + am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // DeletePolicy from the store -func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -384,23 +385,23 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string } account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // ListPolicies from the store -func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -490,7 +491,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { // // Important: Posture checks are applicable only to source group peers, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { +func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) for _, g := range groups { @@ -506,7 +507,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou } // validate the peer based on policy posture checks applied - isValid := account.validatePostureChecksOnPeer(sourcePostureChecksIDs, peer.ID) + isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) if !isValid { continue } @@ -527,7 +528,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou } // validatePostureChecksOnPeer validates the posture checks on a peer -func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) bool { +func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { peer, ok := a.Peers[peerID] if !ok && peer == nil { return false @@ -540,9 +541,9 @@ func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, pe } for _, check := range postureChecks.GetChecks() { - isValid, err := check.Check(*peer) + isValid, err := check.Check(ctx, *peer) if err != nil { - log.Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) + log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) } if !isValid { return false diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 1ea3bb379..bf9a53d16 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "net" "testing" @@ -143,14 +144,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(p.ID, validatedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB", validatedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -386,7 +387,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -414,7 +415,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -444,7 +445,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -465,7 +466,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -662,7 +663,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -672,7 +673,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) expectedFirewallRules := []*FirewallRule{ @@ -688,7 +689,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -698,7 +699,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -713,19 +714,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -740,14 +741,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources("peerA", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index 23b0e8379..f2739dddf 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -1,8 +1,10 @@ package posture import ( - "fmt" + "context" + "errors" "net/netip" + "regexp" "github.com/hashicorp/go-version" "github.com/rs/xid" @@ -17,15 +19,21 @@ const ( OSVersionCheckName = "OSVersionCheck" GeoLocationCheckName = "GeoLocationCheck" PeerNetworkRangeCheckName = "PeerNetworkRangeCheck" + ProcessCheckName = "ProcessCheck" CheckActionAllow string = "allow" CheckActionDeny string = "deny" ) +var ( + countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$") +) + // Check represents an interface for performing a check on a peer. type Check interface { - Check(peer nbpeer.Peer) (bool, error) Name() string + Check(ctx context.Context, peer nbpeer.Peer) (bool, error) + Validate() error } type Checks struct { @@ -51,6 +59,7 @@ type ChecksDefinition struct { OSVersionCheck *OSVersionCheck `json:",omitempty"` GeoLocationCheck *GeoLocationCheck `json:",omitempty"` PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"` + ProcessCheck *ProcessCheck `json:",omitempty"` } // Copy returns a copy of a checks definition. @@ -96,6 +105,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition { } copy(cdCopy.PeerNetworkRangeCheck.Ranges, peerNetRangeCheck.Ranges) } + if cd.ProcessCheck != nil { + processCheck := cd.ProcessCheck + cdCopy.ProcessCheck = &ProcessCheck{ + Processes: make([]Process, len(processCheck.Processes)), + } + copy(cdCopy.ProcessCheck.Processes, processCheck.Processes) + } return cdCopy } @@ -136,6 +152,9 @@ func (pc *Checks) GetChecks() []Check { if pc.Checks.PeerNetworkRangeCheck != nil { checks = append(checks, pc.Checks.PeerNetworkRangeCheck) } + if pc.Checks.ProcessCheck != nil { + checks = append(checks, pc.Checks.ProcessCheck) + } return checks } @@ -191,6 +210,10 @@ func buildPostureCheck(postureChecksID string, name string, description string, } } + if processCheck := checks.ProcessCheck; processCheck != nil { + postureChecks.Checks.ProcessCheck = toProcessCheck(processCheck) + } + return &postureChecks, nil } @@ -221,6 +244,10 @@ func (pc *Checks) ToAPIResponse() *api.PostureCheck { checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(pc.Checks.PeerNetworkRangeCheck) } + if pc.Checks.ProcessCheck != nil { + checks.ProcessCheck = toProcessCheckResponse(pc.Checks.ProcessCheck) + } + return &api.PostureCheck{ Id: pc.ID, Name: pc.Name, @@ -229,44 +256,20 @@ func (pc *Checks) ToAPIResponse() *api.PostureCheck { } } +// Validate checks the validity of a posture checks. func (pc *Checks) Validate() error { - if check := pc.Checks.NBVersionCheck; check != nil { - if !isVersionValid(check.MinVersion) { - return fmt.Errorf("%s version: %s is not valid", check.Name(), check.MinVersion) - } + if pc.Name == "" { + return errors.New("posture checks name shouldn't be empty") } - if osCheck := pc.Checks.OSVersionCheck; osCheck != nil { - if osCheck.Android != nil { - if !isVersionValid(osCheck.Android.MinVersion) { - return fmt.Errorf("%s android version: %s is not valid", osCheck.Name(), osCheck.Android.MinVersion) - } - } + checks := pc.GetChecks() + if len(checks) == 0 { + return errors.New("posture checks shouldn't be empty") + } - if osCheck.Ios != nil { - if !isVersionValid(osCheck.Ios.MinVersion) { - return fmt.Errorf("%s ios version: %s is not valid", osCheck.Name(), osCheck.Ios.MinVersion) - } - } - - if osCheck.Darwin != nil { - if !isVersionValid(osCheck.Darwin.MinVersion) { - return fmt.Errorf("%s darwin version: %s is not valid", osCheck.Name(), osCheck.Darwin.MinVersion) - } - } - - if osCheck.Linux != nil { - if !isVersionValid(osCheck.Linux.MinKernelVersion) { - return fmt.Errorf("%s linux kernel version: %s is not valid", osCheck.Name(), - osCheck.Linux.MinKernelVersion) - } - } - - if osCheck.Windows != nil { - if !isVersionValid(osCheck.Windows.MinKernelVersion) { - return fmt.Errorf("%s windows kernel version: %s is not valid", osCheck.Name(), - osCheck.Windows.MinKernelVersion) - } + for _, check := range checks { + if err := check.Validate(); err != nil { + return err } } @@ -352,3 +355,40 @@ func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*PeerNetworkRang Action: string(check.Action), }, nil } + +func toProcessCheckResponse(check *ProcessCheck) *api.ProcessCheck { + processes := make([]api.Process, 0, len(check.Processes)) + for i := range check.Processes { + processes = append(processes, api.Process{ + LinuxPath: &check.Processes[i].LinuxPath, + MacPath: &check.Processes[i].MacPath, + WindowsPath: &check.Processes[i].WindowsPath, + }) + } + + return &api.ProcessCheck{ + Processes: processes, + } +} + +func toProcessCheck(check *api.ProcessCheck) *ProcessCheck { + processes := make([]Process, 0, len(check.Processes)) + for _, process := range check.Processes { + var p Process + if process.LinuxPath != nil { + p.LinuxPath = *process.LinuxPath + } + if process.MacPath != nil { + p.MacPath = *process.MacPath + } + if process.WindowsPath != nil { + p.WindowsPath = *process.WindowsPath + } + + processes = append(processes, p) + } + + return &ProcessCheck{ + Processes: processes, + } +} diff --git a/management/server/posture/checks_test.go b/management/server/posture/checks_test.go index d36d4f50c..16268b72d 100644 --- a/management/server/posture/checks_test.go +++ b/management/server/posture/checks_test.go @@ -150,9 +150,23 @@ func TestChecks_Validate(t *testing.T) { checks Checks expectedError bool }{ + { + name: "Empty name", + checks: Checks{}, + expectedError: true, + }, + { + name: "Empty checks", + checks: Checks{ + Name: "Default", + Checks: ChecksDefinition{}, + }, + expectedError: true, + }, { name: "Valid checks version", checks: Checks{ + Name: "default", Checks: ChecksDefinition{ NBVersionCheck: &NBVersionCheck{ MinVersion: "0.25.0", @@ -261,6 +275,14 @@ func TestChecks_Copy(t *testing.T) { }, Action: CheckActionDeny, }, + ProcessCheck: &ProcessCheck{ + Processes: []Process{ + { + MacPath: "/Applications/NetBird.app/Contents/MacOS/netbird", + WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe", + }, + }, + }, }, } checkCopy := check.Copy() diff --git a/management/server/posture/geo_location.go b/management/server/posture/geo_location.go index 856913a7a..8a1f38830 100644 --- a/management/server/posture/geo_location.go +++ b/management/server/posture/geo_location.go @@ -1,7 +1,9 @@ package posture import ( + "context" "fmt" + "slices" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -24,7 +26,7 @@ type GeoLocationCheck struct { Action string } -func (g *GeoLocationCheck) Check(peer nbpeer.Peer) (bool, error) { +func (g *GeoLocationCheck) Check(_ context.Context, peer nbpeer.Peer) (bool, error) { // deny if the peer location is not evaluated if peer.Location.CountryCode == "" && peer.Location.CityName == "" { return false, fmt.Errorf("peer's location is not set") @@ -60,3 +62,28 @@ func (g *GeoLocationCheck) Check(peer nbpeer.Peer) (bool, error) { func (g *GeoLocationCheck) Name() string { return GeoLocationCheckName } + +func (g *GeoLocationCheck) Validate() error { + if g.Action == "" { + return fmt.Errorf("%s action shouldn't be empty", g.Name()) + } + + allowedActions := []string{CheckActionAllow, CheckActionDeny} + if !slices.Contains(allowedActions, g.Action) { + return fmt.Errorf("%s action is not valid", g.Name()) + } + + if len(g.Locations) == 0 { + return fmt.Errorf("%s locations shouldn't be empty", g.Name()) + } + + for _, loc := range g.Locations { + if loc.CountryCode == "" { + return fmt.Errorf("%s country code shouldn't be empty", g.Name()) + } + if !countryCodeRegex.MatchString(loc.CountryCode) { + return fmt.Errorf("%s country code must be 2 letters (ISO 3166-1 alpha-2 format)", g.Name()) + } + } + return nil +} diff --git a/management/server/posture/geo_location_test.go b/management/server/posture/geo_location_test.go index 267bbe0f2..a64919f0b 100644 --- a/management/server/posture/geo_location_test.go +++ b/management/server/posture/geo_location_test.go @@ -1,6 +1,7 @@ package posture import ( + "context" "testing" "github.com/netbirdio/netbird/management/server/peer" @@ -226,7 +227,7 @@ func TestGeoLocationCheck_Check(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid, err := tt.check.Check(tt.input) + isValid, err := tt.check.Check(context.Background(), tt.input) if tt.wantErr { assert.Error(t, err) } else { @@ -236,3 +237,81 @@ func TestGeoLocationCheck_Check(t *testing.T) { }) } } + +func TestGeoLocationCheck_Validate(t *testing.T) { + testCases := []struct { + name string + check GeoLocationCheck + expectedError bool + }{ + { + name: "Valid location list", + check: GeoLocationCheck{ + Action: CheckActionAllow, + Locations: []Location{ + { + CountryCode: "DE", + CityName: "Berlin", + }, + }, + }, + expectedError: false, + }, + { + name: "Invalid empty location list", + check: GeoLocationCheck{ + Action: CheckActionDeny, + Locations: []Location{}, + }, + expectedError: true, + }, + { + name: "Invalid empty country name", + check: GeoLocationCheck{ + Action: CheckActionDeny, + Locations: []Location{ + { + CityName: "Los Angeles", + }, + }, + }, + expectedError: true, + }, + { + name: "Invalid check action", + check: GeoLocationCheck{ + Action: "unknownAction", + Locations: []Location{ + { + CountryCode: "DE", + CityName: "Berlin", + }, + }, + }, + expectedError: true, + }, + { + name: "Invalid country code", + check: GeoLocationCheck{ + Action: CheckActionAllow, + Locations: []Location{ + { + CountryCode: "USA", + }, + }, + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.check.Validate() + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/management/server/posture/nb_version.go b/management/server/posture/nb_version.go index 0645b8f73..f63db85b1 100644 --- a/management/server/posture/nb_version.go +++ b/management/server/posture/nb_version.go @@ -1,6 +1,9 @@ package posture import ( + "context" + "fmt" + "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" @@ -13,7 +16,7 @@ type NBVersionCheck struct { var _ Check = (*NBVersionCheck)(nil) -func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) { +func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { peerNBVersion, err := version.NewVersion(peer.Meta.WtVersion) if err != nil { return false, err @@ -28,7 +31,7 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) { return true, nil } - log.Debugf("peer %s NB version %s is older than minimum allowed version %s", + log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s", peer.ID, peer.Meta.WtVersion, n.MinVersion) return false, nil @@ -37,3 +40,13 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) { func (n *NBVersionCheck) Name() string { return NBVersionCheckName } + +func (n *NBVersionCheck) Validate() error { + if n.MinVersion == "" { + return fmt.Errorf("%s minimum version shouldn't be empty", n.Name()) + } + if !isVersionValid(n.MinVersion) { + return fmt.Errorf("%s version: %s is not valid", n.Name(), n.MinVersion) + } + return nil +} diff --git a/management/server/posture/nb_version_test.go b/management/server/posture/nb_version_test.go index de51c2283..1bf485453 100644 --- a/management/server/posture/nb_version_test.go +++ b/management/server/posture/nb_version_test.go @@ -1,6 +1,7 @@ package posture import ( + "context" "testing" "github.com/netbirdio/netbird/management/server/peer" @@ -98,7 +99,7 @@ func TestNBVersionCheck_Check(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid, err := tt.check.Check(tt.input) + isValid, err := tt.check.Check(context.Background(), tt.input) if tt.wantErr { assert.Error(t, err) } else { @@ -108,3 +109,33 @@ func TestNBVersionCheck_Check(t *testing.T) { }) } } + +func TestNBVersionCheck_Validate(t *testing.T) { + testCases := []struct { + name string + check NBVersionCheck + expectedError bool + }{ + { + name: "Valid NBVersionCheck", + check: NBVersionCheck{MinVersion: "1.0"}, + expectedError: false, + }, + { + name: "Invalid NBVersionCheck", + check: NBVersionCheck{}, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.check.Validate() + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/management/server/posture/network.go b/management/server/posture/network.go index 9bf969f4c..0fa6f6e71 100644 --- a/management/server/posture/network.go +++ b/management/server/posture/network.go @@ -1,11 +1,13 @@ package posture import ( + "context" "fmt" "net/netip" "slices" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/status" ) type PeerNetworkRangeCheck struct { @@ -15,7 +17,7 @@ type PeerNetworkRangeCheck struct { var _ Check = (*PeerNetworkRangeCheck)(nil) -func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) { +func (p *PeerNetworkRangeCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { if len(peer.Meta.NetworkAddresses) == 0 { return false, fmt.Errorf("peer's does not contain peer network range addresses") } @@ -52,3 +54,19 @@ func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) { func (p *PeerNetworkRangeCheck) Name() string { return PeerNetworkRangeCheckName } + +func (p *PeerNetworkRangeCheck) Validate() error { + if p.Action == "" { + return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty") + } + + allowedActions := []string{CheckActionAllow, CheckActionDeny} + if !slices.Contains(allowedActions, p.Action) { + return fmt.Errorf("%s action is not valid", p.Name()) + } + + if len(p.Ranges) == 0 { + return fmt.Errorf("%s network ranges shouldn't be empty", p.Name()) + } + return nil +} diff --git a/management/server/posture/network_test.go b/management/server/posture/network_test.go index 36ead4660..a841bbe08 100644 --- a/management/server/posture/network_test.go +++ b/management/server/posture/network_test.go @@ -1,6 +1,7 @@ package posture import ( + "context" "net/netip" "testing" @@ -137,7 +138,7 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid, err := tt.check.Check(tt.peer) + isValid, err := tt.check.Check(context.Background(), tt.peer) if tt.wantErr { assert.Error(t, err) } else { @@ -147,3 +148,52 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) { }) } } + +func TestNetworkCheck_Validate(t *testing.T) { + testCases := []struct { + name string + check PeerNetworkRangeCheck + expectedError bool + }{ + { + name: "Valid network range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + }, + expectedError: false, + }, + { + name: "Invalid empty network range", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{}, + }, + expectedError: true, + }, + { + name: "Invalid check action", + check: PeerNetworkRangeCheck{ + Action: "unknownAction", + Ranges: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + }, + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.check.Validate() + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/management/server/posture/os_version.go b/management/server/posture/os_version.go index 4c311f01b..411f4c2c6 100644 --- a/management/server/posture/os_version.go +++ b/management/server/posture/os_version.go @@ -1,11 +1,14 @@ package posture import ( + "context" + "fmt" "strings" "github.com/hashicorp/go-version" - nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" ) type MinVersionCheck struct { @@ -26,20 +29,20 @@ type OSVersionCheck struct { var _ Check = (*OSVersionCheck)(nil) -func (c *OSVersionCheck) Check(peer nbpeer.Peer) (bool, error) { +func (c *OSVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { peerGoOS := peer.Meta.GoOS switch peerGoOS { case "android": - return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Android) + return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Android) case "darwin": - return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Darwin) + return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Darwin) case "ios": - return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Ios) + return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Ios) case "linux": kernelVersion := strings.Split(peer.Meta.KernelVersion, "-")[0] - return checkMinKernelVersion(peerGoOS, kernelVersion, c.Linux) + return checkMinKernelVersion(ctx, peerGoOS, kernelVersion, c.Linux) case "windows": - return checkMinKernelVersion(peerGoOS, peer.Meta.KernelVersion, c.Windows) + return checkMinKernelVersion(ctx, peerGoOS, peer.Meta.KernelVersion, c.Windows) } return true, nil } @@ -48,9 +51,38 @@ func (c *OSVersionCheck) Name() string { return OSVersionCheckName } -func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) { +func (c *OSVersionCheck) Validate() error { + if c.Android == nil && c.Darwin == nil && c.Ios == nil && c.Linux == nil && c.Windows == nil { + return fmt.Errorf("%s at least one OS version check is required", c.Name()) + } + + if c.Android != nil && !isVersionValid(c.Android.MinVersion) { + return fmt.Errorf("%s android version: %s is not valid", c.Name(), c.Android.MinVersion) + } + + if c.Ios != nil && !isVersionValid(c.Ios.MinVersion) { + return fmt.Errorf("%s ios version: %s is not valid", c.Name(), c.Ios.MinVersion) + } + + if c.Darwin != nil && !isVersionValid(c.Darwin.MinVersion) { + return fmt.Errorf("%s darwin version: %s is not valid", c.Name(), c.Darwin.MinVersion) + } + + if c.Linux != nil && !isVersionValid(c.Linux.MinKernelVersion) { + return fmt.Errorf("%s linux kernel version: %s is not valid", c.Name(), + c.Linux.MinKernelVersion) + } + + if c.Windows != nil && !isVersionValid(c.Windows.MinKernelVersion) { + return fmt.Errorf("%s windows kernel version: %s is not valid", c.Name(), + c.Windows.MinKernelVersion) + } + return nil +} + +func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) { if check == nil { - log.Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) return false, nil } @@ -68,14 +100,14 @@ func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool return true, nil } - log.Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion) + log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion) return false, nil } -func checkMinKernelVersion(peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) { +func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) { if check == nil { - log.Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) return false, nil } @@ -93,7 +125,7 @@ func checkMinKernelVersion(peerGoOS, peerVersion string, check *MinKernelVersion return true, nil } - log.Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion) + log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion) return false, nil } diff --git a/management/server/posture/os_version_test.go b/management/server/posture/os_version_test.go index 32bf52660..76343b0c4 100644 --- a/management/server/posture/os_version_test.go +++ b/management/server/posture/os_version_test.go @@ -1,6 +1,7 @@ package posture import ( + "context" "testing" "github.com/netbirdio/netbird/management/server/peer" @@ -140,7 +141,7 @@ func TestOSVersionCheck_Check(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid, err := tt.check.Check(tt.input) + isValid, err := tt.check.Check(context.Background(), tt.input) if tt.wantErr { assert.Error(t, err) } else { @@ -150,3 +151,79 @@ func TestOSVersionCheck_Check(t *testing.T) { }) } } + +func TestOSVersionCheck_Validate(t *testing.T) { + testCases := []struct { + name string + check OSVersionCheck + expectedError bool + }{ + { + name: "Valid linux kernel version", + check: OSVersionCheck{ + Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"}, + }, + expectedError: false, + }, + { + name: "Valid linux and darwin version", + check: OSVersionCheck{ + Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"}, + Darwin: &MinVersionCheck{MinVersion: "14.2"}, + }, + expectedError: false, + }, + { + name: "Invalid empty check", + check: OSVersionCheck{}, + expectedError: true, + }, + { + name: "Invalid empty linux kernel version", + check: OSVersionCheck{ + Linux: &MinKernelVersionCheck{}, + }, + expectedError: true, + }, + { + name: "Invalid empty linux kernel version with correct darwin version", + check: OSVersionCheck{ + Linux: &MinKernelVersionCheck{}, + Darwin: &MinVersionCheck{MinVersion: "14.2"}, + }, + expectedError: true, + }, + { + name: "Valid windows kernel version", + check: OSVersionCheck{ + Windows: &MinKernelVersionCheck{MinKernelVersion: "10.0"}, + }, + expectedError: false, + }, + { + name: "Valid ios minimum version", + check: OSVersionCheck{ + Ios: &MinVersionCheck{MinVersion: "13.0"}, + }, + expectedError: false, + }, + { + name: "Invalid empty window version with valid ios minimum version", + check: OSVersionCheck{ + Windows: &MinKernelVersionCheck{}, + Ios: &MinVersionCheck{MinVersion: "13.0"}, + }, + expectedError: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.check.Validate() + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/management/server/posture/process.go b/management/server/posture/process.go new file mode 100644 index 000000000..911aabb52 --- /dev/null +++ b/management/server/posture/process.go @@ -0,0 +1,80 @@ +package posture + +import ( + "context" + "fmt" + "slices" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +type Process struct { + LinuxPath string + MacPath string + WindowsPath string +} + +type ProcessCheck struct { + Processes []Process +} + +var _ Check = (*ProcessCheck)(nil) + +func (p *ProcessCheck) Check(_ context.Context, peer nbpeer.Peer) (bool, error) { + peerActiveProcesses := extractPeerActiveProcesses(peer.Meta.Files) + + var pathSelector func(Process) string + switch peer.Meta.GoOS { + case "linux": + pathSelector = func(process Process) string { return process.LinuxPath } + case "darwin": + pathSelector = func(process Process) string { return process.MacPath } + case "windows": + pathSelector = func(process Process) string { return process.WindowsPath } + default: + return false, fmt.Errorf("unsupported peer's operating system: %s", peer.Meta.GoOS) + } + + return p.areAllProcessesRunning(peerActiveProcesses, pathSelector), nil +} + +func (p *ProcessCheck) Name() string { + return ProcessCheckName +} + +func (p *ProcessCheck) Validate() error { + if len(p.Processes) == 0 { + return fmt.Errorf("%s processes shouldn't be empty", p.Name()) + } + + for _, process := range p.Processes { + if process.LinuxPath == "" && process.MacPath == "" && process.WindowsPath == "" { + return fmt.Errorf("%s path shouldn't be empty", p.Name()) + } + } + return nil +} + +// areAllProcessesRunning checks if all processes specified in ProcessCheck are running. +// It uses the provided pathSelector to get the appropriate process path for the peer's OS. +// It returns true if all processes are running, otherwise false. +func (p *ProcessCheck) areAllProcessesRunning(activeProcesses []string, pathSelector func(Process) string) bool { + for _, process := range p.Processes { + path := pathSelector(process) + if path == "" || !slices.Contains(activeProcesses, path) { + return false + } + } + return true +} + +// extractPeerActiveProcesses extracts the paths of running processes from the peer meta. +func extractPeerActiveProcesses(files []nbpeer.File) []string { + activeProcesses := make([]string, 0, len(files)) + for _, file := range files { + if file.ProcessIsRunning { + activeProcesses = append(activeProcesses, file.Path) + } + } + return activeProcesses +} diff --git a/management/server/posture/process_test.go b/management/server/posture/process_test.go new file mode 100644 index 000000000..ce43a948a --- /dev/null +++ b/management/server/posture/process_test.go @@ -0,0 +1,319 @@ +package posture + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/peer" +) + +func TestProcessCheck_Check(t *testing.T) { + tests := []struct { + name string + input peer.Peer + check ProcessCheck + wantErr bool + isValid bool + }{ + { + name: "darwin with matching running processes", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "darwin", + Files: []peer.File{ + {Path: "/Applications/process1.app", ProcessIsRunning: true}, + {Path: "/Applications/process2.app", ProcessIsRunning: true}, + }, + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {MacPath: "/Applications/process1.app"}, + {MacPath: "/Applications/process2.app"}, + }, + }, + wantErr: false, + isValid: true, + }, + { + name: "darwin with windows process paths", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "darwin", + Files: []peer.File{ + {Path: "/Applications/process1.app", ProcessIsRunning: true}, + {Path: "/Applications/process2.app", ProcessIsRunning: true}, + }, + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {WindowsPath: "C:\\Program Files\\process1.exe"}, + {WindowsPath: "C:\\Program Files\\process2.exe"}, + }, + }, + wantErr: false, + isValid: false, + }, + { + name: "linux with matching running processes", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "linux", + Files: []peer.File{ + {Path: "/usr/bin/process1", ProcessIsRunning: true}, + {Path: "/usr/bin/process2", ProcessIsRunning: true}, + }, + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {LinuxPath: "/usr/bin/process1"}, + {LinuxPath: "/usr/bin/process2"}, + }, + }, + wantErr: false, + isValid: true, + }, + { + name: "linux with matching no running processes", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "linux", + Files: []peer.File{ + {Path: "/usr/bin/process1", ProcessIsRunning: true}, + {Path: "/usr/bin/process2", ProcessIsRunning: false}, + }, + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {LinuxPath: "/usr/bin/process1"}, + {LinuxPath: "/usr/bin/process2"}, + }, + }, + wantErr: false, + isValid: false, + }, + { + name: "linux with windows process paths", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "linux", + Files: []peer.File{ + {Path: "/usr/bin/process1", ProcessIsRunning: true}, + {Path: "/usr/bin/process2"}, + }, + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {WindowsPath: "C:\\Program Files\\process1.exe"}, + {WindowsPath: "C:\\Program Files\\process2.exe"}, + }, + }, + wantErr: false, + isValid: false, + }, + { + name: "linux with non-matching processes", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "linux", + Files: []peer.File{ + {Path: "/usr/bin/process3"}, + {Path: "/usr/bin/process4"}, + }, + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {LinuxPath: "/usr/bin/process1"}, + {LinuxPath: "/usr/bin/process2"}, + }, + }, + wantErr: false, + isValid: false, + }, + { + name: "windows with matching running processes", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "windows", + Files: []peer.File{ + {Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true}, + {Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true}, + }, + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {WindowsPath: "C:\\Program Files\\process1.exe"}, + {WindowsPath: "C:\\Program Files\\process1.exe"}, + }, + }, + wantErr: false, + isValid: true, + }, + { + name: "windows with darwin process paths", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "windows", + Files: []peer.File{ + {Path: "C:\\Program Files\\process1.exe"}, + {Path: "C:\\Program Files\\process1.exe"}, + }, + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {MacPath: "/Applications/process1.app"}, + {LinuxPath: "/Applications/process2.app"}, + }, + }, + wantErr: false, + isValid: false, + }, + { + name: "windows with non-matching processes", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "windows", + Files: []peer.File{ + {Path: "C:\\Program Files\\process3.exe"}, + {Path: "C:\\Program Files\\process4.exe"}, + }, + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {WindowsPath: "C:\\Program Files\\process1.exe"}, + {WindowsPath: "C:\\Program Files\\process2.exe"}, + }, + }, + wantErr: false, + isValid: false, + }, + { + name: "unsupported ios operating system", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "ios", + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {WindowsPath: "C:\\Program Files\\process1.exe"}, + {MacPath: "/Applications/process2.app"}, + }, + }, + wantErr: true, + isValid: false, + }, + { + name: "unsupported android operating system", + input: peer.Peer{ + Meta: peer.PeerSystemMeta{ + GoOS: "android", + }, + }, + check: ProcessCheck{ + Processes: []Process{ + {WindowsPath: "C:\\Program Files\\process1.exe"}, + {MacPath: "/Applications/process2.app"}, + {LinuxPath: "/usr/bin/process2"}, + }, + }, + wantErr: true, + isValid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isValid, err := tt.check.Check(context.Background(), tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.isValid, isValid) + }) + } +} + +func TestProcessCheck_Validate(t *testing.T) { + testCases := []struct { + name string + check ProcessCheck + expectedError bool + }{ + { + name: "Valid linux, mac and windows processes", + check: ProcessCheck{ + Processes: []Process{ + { + LinuxPath: "/usr/local/bin/netbird", + MacPath: "/usr/local/bin/netbird", + WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe", + }, + }, + }, + expectedError: false, + }, + { + name: "Valid linux process", + check: ProcessCheck{ + Processes: []Process{ + { + LinuxPath: "/usr/local/bin/netbird", + }, + }, + }, + expectedError: false, + }, + { + name: "Valid mac process", + check: ProcessCheck{ + Processes: []Process{ + { + MacPath: "/Applications/NetBird.app/Contents/MacOS/netbird", + }, + }, + }, + expectedError: false, + }, + { + name: "Valid windows process", + check: ProcessCheck{ + Processes: []Process{ + { + WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe", + }, + }, + }, + expectedError: false, + }, + { + name: "Invalid empty processes", + check: ProcessCheck{ + Processes: []Process{}, + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.check.Validate() + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index fb904c10f..851d4d31f 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -1,16 +1,24 @@ package server import ( + "context" + "slices" + "github.com/netbirdio/netbird/management/server/activity" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +const ( + errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks" +) + +func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -21,7 +29,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us } if !user.HasAdminPower() { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks") + return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } for _, postureChecks := range account.PostureChecks { @@ -33,11 +41,11 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID) } -func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -48,11 +56,11 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks") + return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.BadRequest, err.Error()) + return status.Errorf(status.InvalidArgument, err.Error()) } exists, uniqName := am.savePostureChecks(account, postureChecks) @@ -68,23 +76,23 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos account.Network.IncSerial() } - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) + am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if exists { - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } return nil } -func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -95,7 +103,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks") + return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } postureChecks, err := am.deletePostureChecks(account, postureChecksID) @@ -103,20 +111,20 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, return err } - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) + am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) return nil } -func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -127,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([] } if !user.HasAdminPower() { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view posture checks") + return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } return account.PostureChecks, nil @@ -176,3 +184,58 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh return postureChecks, nil } + +// getPeerPostureChecks returns the posture checks applied for a given peer. +func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { + peerPostureChecks := make(map[string]posture.Checks) + + if len(account.PostureChecks) == 0 { + return nil + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + if isPeerInPolicySourceGroups(peer.ID, account, policy) { + addPolicyPostureChecks(account, policy, peerPostureChecks) + } + } + + postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) + for _, check := range peerPostureChecks { + checkCopy := check + postureChecksList = append(postureChecksList, &checkCopy) + } + + return postureChecksList +} + +// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. +func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, sourceGroup := range rule.Sources { + group, ok := account.Groups[sourceGroup] + if ok && slices.Contains(group.Peers, peerID) { + return true + } + } + } + + return false +} + +func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) { + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + for _, postureCheck := range account.PostureChecks { + if postureCheck.ID == sourcePostureCheckID { + peerPostureChecks[sourcePostureCheckID] = *postureCheck + } + } + } +} diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index dd92fe8b9..d837120f4 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -28,15 +29,15 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - err := am.SavePostureChecks(account.Id, regularUserID, &posture.Checks{}) + err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check - _, err = am.ListPostureChecks(account.Id, regularUserID) + _, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) assert.Error(t, err) // should be possible to create posture check with uniq name - err = am.SavePostureChecks(account.Id, adminUserID, &posture.Checks{ + err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ ID: postureCheckID, Name: postureCheckName, Checks: posture.ChecksDefinition{ @@ -48,12 +49,12 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.NoError(t, err) // admin users can list check - checks, err := am.ListPostureChecks(account.Id, adminUserID) + checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - err = am.SavePostureChecks(account.Id, adminUserID, &posture.Checks{ + err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ ID: "new-id", Name: postureCheckName, Checks: posture.ChecksDefinition{ @@ -69,7 +70,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // admins can update posture checks - err = am.SavePostureChecks(account.Id, adminUserID, &posture.Checks{ + err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ ID: postureCheckID, Name: postureCheckName, Checks: posture.ChecksDefinition{ @@ -81,13 +82,13 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(account.Id, postureCheckID, regularUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(account.Id, postureCheckID, adminUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) assert.NoError(t, err) - checks, err = am.ListPostureChecks(account.Id, adminUserID) + checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 0) }) @@ -106,14 +107,14 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { Role: UserRoleUser, } - account := newAccountWithId(accountID, groupAdminUserID, domain) + account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) account.Users[admin.Id] = admin account.Users[user.Id] = user - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - return am.Store.GetAccount(account.Id) + return am.Store.GetAccount(context.Background(), account.Id) } diff --git a/management/server/route.go b/management/server/route.go index 2de813d48..6db00a255 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -1,11 +1,14 @@ package server import ( + "context" + "fmt" "net/netip" "unicode/utf8" "github.com/rs/xid" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" @@ -13,11 +16,11 @@ import ( ) // GetRoute gets a route object from account and route IDs -func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -39,10 +42,10 @@ func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, us return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) } -// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. -func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix) error { +// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. +func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { // routes can have both peer and peer_groups - routesWithPrefix := account.GetRoutesByPrefix(prefix) + routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) // lets remember all the peers and the peer groups from routesWithPrefix seenPeers := make(map[string]bool) @@ -50,7 +53,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account for _, prefixRoute := range routesWithPrefix { // we skip route(s) with the same network ID as we want to allow updating of the existing route - // when create a new route routeID is newly generated so nothing will be skipped + // when creating a new route routeID is newly generated so nothing will be skipped if routeID == prefixRoute.ID { continue } @@ -64,8 +67,9 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account group := account.GetGroup(groupID) if group == nil { return status.Errorf( - status.InvalidArgument, "failed to add route with prefix %s - peer group %s doesn't exist", - prefix.String(), groupID) + status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", + getRouteDescriptor(prefix, domains), groupID, + ) } for _, pID := range group.Peers { @@ -82,18 +86,18 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account } if _, ok := seenPeers[peerID]; ok { return status.Errorf(status.AlreadyExists, - "failed to add route with prefix %s - peer %s already has this route", prefix.String(), peerID) + "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) } } // check that peerGroupIDs are not in any route peerGroups list for _, groupID := range peerGroupIDs { - group := account.GetGroup(groupID) // we validated the group existent before entering this function, o need to check again. + group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. if _, ok := seenPeerGroups[groupID]; ok { return status.Errorf( - status.AlreadyExists, "failed to add route with prefix %s - peer group %s already has this route", - prefix.String(), group.Name) + status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", + getRouteDescriptor(prefix, domains), group.Name) } // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix @@ -104,8 +108,8 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } return status.Errorf(status.AlreadyExists, - "failed to add route with prefix %s - peer %s from the group %s already has this route", - prefix.String(), peer.Name, group.Name) + "failed to add route with %s - peer %s from the group %s already has this route", + getRouteDescriptor(prefix, domains), peer.Name, group.Name) } } } @@ -113,16 +117,35 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account return nil } +func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { + if len(domains) > 0 { + return fmt.Sprintf("domains [%s]", domains.SafeString()) + } + return fmt.Sprintf("prefix %s", prefix.String()) +} + // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } + if len(domains) > 0 && prefix.IsValid() { + return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") + } + + if len(domains) == 0 && !prefix.IsValid() { + return nil, status.Errorf(status.InvalidArgument, "invalid Prefix") + } + + if len(domains) > 0 { + prefix = getPlaceholderIP() + } + if peerID != "" && len(peerGroupIDs) != 0 { return nil, status.Errorf( status.InvalidArgument, @@ -133,11 +156,6 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, var newRoute route.Route newRoute.ID = route.ID(xid.New().String()) - prefixType, newPrefix, err := route.ParseNetwork(network) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network) - } - if len(peerGroupIDs) > 0 { err = validateGroups(peerGroupIDs, account.Groups) if err != nil { @@ -145,7 +163,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, } } - err = am.checkRoutePrefixExistsForPeers(account, peerID, newRoute.ID, peerGroupIDs, newPrefix) + err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) if err != nil { return nil, err } @@ -165,14 +183,16 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, newRoute.Peer = peerID newRoute.PeerGroups = peerGroupIDs - newRoute.Network = newPrefix - newRoute.NetworkType = prefixType + newRoute.Network = prefix + newRoute.Domains = domains + newRoute.NetworkType = networkType newRoute.Description = description newRoute.NetID = netID newRoute.Masquerade = masquerade newRoute.Metric = metric newRoute.Enabled = enabled newRoute.Groups = groups + newRoute.KeepRoute = keepRoute if account.Routes == nil { account.Routes = make(map[route.ID]*route.Route) @@ -181,30 +201,26 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, account.Routes[newRoute.ID] = &newRoute account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) + am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) return &newRoute, nil } // SaveRoute saves route -func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if routeToSave == nil { return status.Errorf(status.InvalidArgument, "route provided is nil") } - if !routeToSave.Network.IsValid() { - return status.Errorf(status.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String()) - } - if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric { return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) } @@ -213,11 +229,23 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } + if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { + return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") + } + + if len(routeToSave.Domains) == 0 && !routeToSave.Network.IsValid() { + return status.Errorf(status.InvalidArgument, "invalid Prefix") + } + + if len(routeToSave.Domains) > 0 { + routeToSave.Network = getPlaceholderIP() + } + if routeToSave.Peer != "" && len(routeToSave.PeerGroups) != 0 { return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") } @@ -229,7 +257,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave } } - err = am.checkRoutePrefixExistsForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network) + err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) if err != nil { return err } @@ -242,23 +270,23 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave account.Routes[routeToSave.ID] = routeToSave account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) + am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) return nil } // DeleteRoute deletes route with routeID -func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -270,23 +298,23 @@ func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, delete(account.Routes, routeID) account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) + am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // ListRoutes returns a list of routes from account -func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -313,10 +341,12 @@ func toProtocolRoute(route *route.Route) *proto.Route { ID: string(route.ID), NetID: string(route.NetID), Network: route.Network.String(), + Domains: route.Domains.ToPunycodeList(), NetworkType: int64(route.NetworkType), Peer: route.Peer, Metric: int64(route.Metric), Masquerade: route.Masquerade, + KeepRoute: route.KeepRoute, } } @@ -327,3 +357,9 @@ func toProtocolRoutes(routes []*route.Route) []*proto.Route { } return protoRoutes } + +// getPlaceholderIP returns a placeholder IP address for the route if domains are used +func getPlaceholderIP() netip.Prefix { + // Using an IP from the documentation range to minimize impact in case older clients try to set a route + return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) +} diff --git a/management/server/route_test.go b/management/server/route_test.go index d28b40d48..8b168a79f 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "net/netip" "testing" @@ -8,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -33,13 +35,18 @@ const ( routeGroupHA2 = "routeGroupHA2" routeInvalidGroup1 = "routeInvalidGroup1" userID = "testingUser" - existingNetwork = "10.10.10.0/24" existingRouteID = "random-id" ) +var existingNetwork = netip.MustParsePrefix("10.10.10.0/24") +var existingDomains = domain.List{"example.com"} + func TestCreateRoute(t *testing.T) { type input struct { - network string + network netip.Prefix + domains domain.List + keepRoute bool + networkType route.NetworkType netID route.NetID peerKey string peerGroupIDs []string @@ -59,9 +66,10 @@ func TestCreateRoute(t *testing.T) { expectedRoute *route.Route }{ { - name: "Happy Path", + name: "Happy Path Network", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, netID: "happy", peerKey: peer1ID, description: "super", @@ -84,10 +92,41 @@ func TestCreateRoute(t *testing.T) { Groups: []string{routeGroup1}, }, }, + { + name: "Happy Path Domains", + inputArgs: input{ + domains: domain.List{"domain1", "domain2"}, + keepRoute: true, + networkType: route.DomainNetwork, + netID: "happy", + peerKey: peer1ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + }, + errFunc: require.NoError, + shouldCreate: true, + expectedRoute: &route.Route{ + Network: netip.MustParsePrefix("192.0.2.0/32"), + Domains: domain.List{"domain1", "domain2"}, + NetworkType: route.DomainNetwork, + NetID: "happy", + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + KeepRoute: true, + }, + }, { name: "Happy Path Peer Groups", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, netID: "happy", peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, description: "super", @@ -111,9 +150,10 @@ func TestCreateRoute(t *testing.T) { }, }, { - name: "Both peer and peer_groups Provided Should Fail", + name: "Both network and domains provided should fail", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + domains: domain.List{"domain1", "domain2"}, netID: "happy", peerKey: peer1ID, peerGroupIDs: []string{routeGroupHA1}, @@ -127,16 +167,18 @@ func TestCreateRoute(t *testing.T) { shouldCreate: false, }, { - name: "Bad Prefix Should Fail", + name: "Both peer and peer_groups Provided Should Fail", inputArgs: input{ - network: "192.168.0.0/34", - netID: "happy", - peerKey: peer1ID, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, }, errFunc: require.Error, shouldCreate: false, @@ -144,7 +186,8 @@ func TestCreateRoute(t *testing.T) { { name: "Bad Peer Should Fail", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, netID: "happy", peerKey: "notExistingPeer", description: "super", @@ -157,9 +200,10 @@ func TestCreateRoute(t *testing.T) { shouldCreate: false, }, { - name: "Bad Peer already has this route", + name: "Bad Peer already has this network route", inputArgs: input{ network: existingNetwork, + networkType: route.IPv4Network, netID: "bad", peerKey: peer5ID, description: "super", @@ -173,9 +217,44 @@ func TestCreateRoute(t *testing.T) { shouldCreate: false, }, { - name: "Bad Peers Group already has this route", + name: "Bad Peer already has this domains route", + inputArgs: input{ + domains: existingDomains, + networkType: route.DomainNetwork, + netID: "bad", + peerKey: peer5ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + }, + createInitRoute: true, + errFunc: require.Error, + shouldCreate: false, + }, + { + name: "Bad Peers Group already has this network route", inputArgs: input{ network: existingNetwork, + networkType: route.IPv4Network, + netID: "bad", + peerGroupIDs: []string{routeGroup1, routeGroup3}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + }, + createInitRoute: true, + errFunc: require.Error, + shouldCreate: false, + }, + { + name: "Bad Peers Group already has this domains route", + inputArgs: input{ + domains: existingDomains, + networkType: route.DomainNetwork, netID: "bad", peerGroupIDs: []string{routeGroup1, routeGroup3}, description: "super", @@ -191,7 +270,8 @@ func TestCreateRoute(t *testing.T) { { name: "Empty Peer Should Create", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, netID: "happy", peerKey: "", description: "super", @@ -217,7 +297,8 @@ func TestCreateRoute(t *testing.T) { { name: "Large Metric Should Fail", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, peerKey: peer1ID, netID: "happy", description: "super", @@ -232,7 +313,8 @@ func TestCreateRoute(t *testing.T) { { name: "Small Metric Should Fail", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, netID: "happy", peerKey: peer1ID, description: "super", @@ -247,7 +329,8 @@ func TestCreateRoute(t *testing.T) { { name: "Large NetID Should Fail", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, peerKey: peer1ID, netID: "12345678901234567890qwertyuiopqwertyuiop1", description: "super", @@ -262,7 +345,8 @@ func TestCreateRoute(t *testing.T) { { name: "Small NetID Should Fail", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, netID: "", peerKey: peer1ID, description: "", @@ -277,7 +361,8 @@ func TestCreateRoute(t *testing.T) { { name: "Empty Group List Should Fail", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, netID: "NewId", peerKey: peer1ID, description: "", @@ -292,7 +377,8 @@ func TestCreateRoute(t *testing.T) { { name: "Empty Group ID string Should Fail", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, netID: "NewId", peerKey: peer1ID, description: "", @@ -307,7 +393,8 @@ func TestCreateRoute(t *testing.T) { { name: "Invalid Group Should Fail", inputArgs: input{ - network: "192.168.0.0/16", + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, netID: "NewId", peerKey: peer1ID, description: "", @@ -334,29 +421,14 @@ func TestCreateRoute(t *testing.T) { if testCase.createInitRoute { groupAll, errInit := account.GetGroupAll() - if errInit != nil { - t.Errorf("failed to get group all: %s", errInit) - } - _, errInit = am.CreateRoute(account.Id, existingNetwork, "", []string{routeGroup3, routeGroup4}, - "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID) - if errInit != nil { - t.Errorf("failed to create init route: %s", errInit) - } + require.NoError(t, errInit) + _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + require.NoError(t, errInit) + _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + require.NoError(t, errInit) } - outRoute, err := am.CreateRoute( - account.Id, - testCase.inputArgs.network, - testCase.inputArgs.peerKey, - testCase.inputArgs.peerGroupIDs, - testCase.inputArgs.description, - testCase.inputArgs.netID, - testCase.inputArgs.masquerade, - testCase.inputArgs.metric, - testCase.inputArgs.groups, - testCase.inputArgs.enabled, - userID, - ) + outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) @@ -379,8 +451,13 @@ func TestSaveRoute(t *testing.T) { validUsedPeer := peer5ID invalidPeer := "nonExisting" validPrefix := netip.MustParsePrefix("192.168.0.0/24") + placeholderPrefix := netip.MustParsePrefix("192.0.2.0/32") invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34") validMetric := 1000 + trueKeepRoute := true + falseKeepRoute := false + ipv4networkType := route.IPv4Network + domainNetworkType := route.DomainNetwork invalidMetric := 99999 validNetID := route.NetID("12345678901234567890qw") invalidNetID := route.NetID("12345678901234567890qwertyuiopqwertyuiop1") @@ -395,6 +472,9 @@ func TestSaveRoute(t *testing.T) { newPeerGroups []string newMetric *int newPrefix *netip.Prefix + newDomains domain.List + newNetworkType *route.NetworkType + newKeepRoute *bool newGroups []string skipCopying bool shouldCreate bool @@ -402,7 +482,7 @@ func TestSaveRoute(t *testing.T) { expectedRoute *route.Route }{ { - name: "Happy Path", + name: "Happy Path Network", existingRoute: &route.Route{ ID: "testingRoute", Network: netip.MustParsePrefix("192.168.0.0/16"), @@ -434,6 +514,45 @@ func TestSaveRoute(t *testing.T) { Groups: []string{routeGroup2}, }, }, + { + name: "Happy Path Domains", + existingRoute: &route.Route{ + ID: "testingRoute", + Network: netip.Prefix{}, + Domains: domain.List{"example.com"}, + KeepRoute: false, + NetID: validNetID, + NetworkType: route.DomainNetwork, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + newPeer: &validPeer, + newMetric: &validMetric, + newPrefix: &netip.Prefix{}, + newDomains: domain.List{"example.com", "example2.com"}, + newKeepRoute: &trueKeepRoute, + newGroups: []string{routeGroup1}, + errFunc: require.NoError, + shouldCreate: true, + expectedRoute: &route.Route{ + ID: "testingRoute", + Network: placeholderPrefix, + Domains: domain.List{"example.com", "example2.com"}, + KeepRoute: true, + NetID: validNetID, + NetworkType: route.DomainNetwork, + Peer: validPeer, + Description: "super", + Masquerade: false, + Metric: validMetric, + Enabled: true, + Groups: []string{routeGroup1}, + }, + }, { name: "Happy Path Peer Groups", existingRoute: &route.Route{ @@ -466,6 +585,23 @@ func TestSaveRoute(t *testing.T) { Groups: []string{routeGroup2}, }, }, + { + name: "Both network and domains provided should fail", + existingRoute: &route.Route{ + ID: "testingRoute", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: validNetID, + NetworkType: route.IPv4Network, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + newPrefix: &validPrefix, + newDomains: domain.List{"example.com"}, + errFunc: require.Error, + }, { name: "Both peer and peers_roup Provided Should Fail", existingRoute: &route.Route{ @@ -623,7 +759,7 @@ func TestSaveRoute(t *testing.T) { name: "Allow to modify existing route with new peer", existingRoute: &route.Route{ ID: "testingRoute", - Network: netip.MustParsePrefix(existingNetwork), + Network: existingNetwork, NetID: validNetID, NetworkType: route.IPv4Network, Peer: peer1ID, @@ -638,7 +774,7 @@ func TestSaveRoute(t *testing.T) { shouldCreate: true, expectedRoute: &route.Route{ ID: "testingRoute", - Network: netip.MustParsePrefix(existingNetwork), + Network: existingNetwork, NetID: validNetID, NetworkType: route.IPv4Network, Peer: validPeer, @@ -654,7 +790,7 @@ func TestSaveRoute(t *testing.T) { name: "Do not allow to modify existing route with a peer from another route", existingRoute: &route.Route{ ID: "testingRoute", - Network: netip.MustParsePrefix(existingNetwork), + Network: existingNetwork, NetID: validNetID, NetworkType: route.IPv4Network, Peer: peer1ID, @@ -672,7 +808,7 @@ func TestSaveRoute(t *testing.T) { name: "Do not allow to modify existing route with a peers group from another route", existingRoute: &route.Route{ ID: "testingRoute", - Network: netip.MustParsePrefix(existingNetwork), + Network: existingNetwork, NetID: validNetID, NetworkType: route.IPv4Network, PeerGroups: []string{routeGroup3}, @@ -686,6 +822,80 @@ func TestSaveRoute(t *testing.T) { newPeerGroups: []string{routeGroup4}, errFunc: require.Error, }, + { + name: "Allow switching from network route to domains route", + existingRoute: &route.Route{ + ID: "testingRoute", + Network: validPrefix, + Domains: nil, + KeepRoute: false, + NetID: validNetID, + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + newPrefix: &netip.Prefix{}, + newDomains: domain.List{"example.com"}, + newNetworkType: &domainNetworkType, + newKeepRoute: &trueKeepRoute, + errFunc: require.NoError, + shouldCreate: true, + expectedRoute: &route.Route{ + ID: "testingRoute", + Network: placeholderPrefix, + NetworkType: route.DomainNetwork, + Domains: domain.List{"example.com"}, + KeepRoute: true, + NetID: validNetID, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + }, + { + name: "Allow switching from domains route to network route", + existingRoute: &route.Route{ + ID: "testingRoute", + Network: placeholderPrefix, + Domains: domain.List{"example.com"}, + KeepRoute: true, + NetID: validNetID, + NetworkType: route.DomainNetwork, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + newPrefix: &validPrefix, + newDomains: nil, + newKeepRoute: &falseKeepRoute, + newNetworkType: &ipv4networkType, + errFunc: require.NoError, + shouldCreate: true, + expectedRoute: &route.Route{ + ID: "testingRoute", + Network: validPrefix, + NetworkType: route.IPv4Network, + KeepRoute: false, + Domains: nil, + NetID: validNetID, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { @@ -702,7 +912,7 @@ func TestSaveRoute(t *testing.T) { if testCase.createInitRoute { account.Routes["initRoute"] = &route.Route{ ID: "initRoute", - Network: netip.MustParsePrefix(existingNetwork), + Network: existingNetwork, NetID: existingRouteID, NetworkType: route.IPv4Network, PeerGroups: []string{routeGroup4}, @@ -716,7 +926,7 @@ func TestSaveRoute(t *testing.T) { account.Routes[testCase.existingRoute.ID] = testCase.existingRoute - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("account should be saved") } @@ -739,12 +949,22 @@ func TestSaveRoute(t *testing.T) { routeToSave.Network = *testCase.newPrefix } + routeToSave.Domains = testCase.newDomains + + if testCase.newNetworkType != nil { + routeToSave.NetworkType = *testCase.newNetworkType + } + + if testCase.newKeepRoute != nil { + routeToSave.KeepRoute = *testCase.newKeepRoute + } + if testCase.newGroups != nil { routeToSave.Groups = testCase.newGroups } } - err = am.SaveRoute(account.Id, userID, routeToSave) + err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave) testCase.errFunc(t, err) @@ -752,7 +972,7 @@ func TestSaveRoute(t *testing.T) { return } - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) } @@ -771,6 +991,8 @@ func TestDeleteRoute(t *testing.T) { testingRoute := &route.Route{ ID: "testingRoute", Network: netip.MustParsePrefix("192.168.0.0/16"), + Domains: domain.List{"domain1", "domain2"}, + KeepRoute: true, NetworkType: route.IPv4Network, Peer: peer1Key, Description: "super", @@ -791,17 +1013,17 @@ func TestDeleteRoute(t *testing.T) { account.Routes[testingRoute.ID] = testingRoute - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("failed to save account") } - err = am.DeleteRoute(account.Id, testingRoute.ID, userID) + err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID) if err != nil { t.Error("deleting route failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(account.Id) + savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Error("failed to retrieve saved account with error: ", err) } @@ -835,29 +1057,27 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { t.Error("failed to init testing account") } - newAccountRoutes, err := am.GetNetworkMap(peer1ID) + newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute( - account.Id, baseRoute.Network.String(), baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, - baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID) + newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) - peer1Routes, err := am.GetNetworkMap(peer1ID) + peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) assert.Len(t, peer1Routes.Routes, 1, "HA route should have 1 server route") - peer2Routes, err := am.GetNetworkMap(peer2ID) + peer2Routes, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2Routes.Routes, 1, "HA route should have 1 server route") - peer4Routes, err := am.GetNetworkMap(peer4ID) + peer4Routes, err := am.GetNetworkMap(context.Background(), peer4ID) require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.ListGroups(account.Id) + groups, err := am.ListGroups(context.Background(), account.Id) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { @@ -869,35 +1089,35 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { } } - err = am.GroupDeletePeer(account.Id, groupHA1.ID, peer2ID) + err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID) require.NoError(t, err) - peer2RoutesAfterDelete, err := am.GetNetworkMap(peer2ID) + peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") - err = am.GroupDeletePeer(account.Id, groupHA2.ID, peer4ID) + err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID) require.NoError(t, err) - peer2RoutesAfterDelete, err = am.GetNetworkMap(peer2ID) + peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") - err = am.GroupAddPeer(account.Id, groupHA2.ID, peer4ID) + err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID) require.NoError(t, err) - peer1RoutesAfterAdd, err := am.GetNetworkMap(peer1ID) + peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) assert.Len(t, peer1RoutesAfterAdd.Routes, 1, "HA route should have more than 1 route") - peer2RoutesAfterAdd, err := am.GetNetworkMap(peer2ID) + peer2RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") - err = am.DeleteRoute(account.Id, newRoute.ID, userID) + err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID) require.NoError(t, err) - peer1DeletedRoute, err := am.GetNetworkMap(peer1ID) + peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) assert.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") } @@ -928,16 +1148,14 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { t.Error("failed to init testing account") } - newAccountRoutes, err := am.GetNetworkMap(peer1ID) + newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network.String(), peer1ID, []string{}, - baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, - userID) + createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) - noDisabledRoutes, err := am.GetNetworkMap(peer1ID) + noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, noDisabledRoutes.Routes, 0, "no routes for disabled routes") @@ -948,22 +1166,22 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { expectedRoute := enabledRoute.Copy() expectedRoute.Peer = peer1Key - err = am.SaveRoute(account.Id, userID, enabledRoute) + err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute) require.NoError(t, err) - peer1Routes, err := am.GetNetworkMap(peer1ID) + peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, peer1Routes.Routes, 1, "we should receive one route for peer1") require.True(t, expectedRoute.IsEqual(peer1Routes.Routes[0]), "received route should be equal") - peer2Routes, err := am.GetNetworkMap(peer2ID) + peer2Routes, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group") - err = am.GroupAddPeer(account.Id, routeGroup1, peer2ID) + err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID) require.NoError(t, err) - peer2Routes, err = am.GetNetworkMap(peer2ID) + peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") @@ -973,10 +1191,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { Name: "peer1 group", Peers: []string{peer1ID}, } - err = am.SaveGroup(account.Id, userID, newGroup) + err = am.SaveGroup(context.Background(), account.Id, userID, newGroup) require.NoError(t, err) - rules, err := am.ListPolicies(account.Id, "testingUser") + rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") require.NoError(t, err) defaultRule := rules[0] @@ -986,24 +1204,24 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(account.Id, userID, newPolicy) + err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) require.NoError(t, err) - err = am.DeletePolicy(account.Id, defaultRule.ID, userID) + err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) require.NoError(t, err) - peer1GroupRoutes, err := am.GetNetworkMap(peer1ID) + peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, peer1GroupRoutes.Routes, 1, "we should receive one route for peer1") - peer2GroupRoutes, err := am.GetNetworkMap(peer2ID) + peer2GroupRoutes, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2") - err = am.DeleteRoute(account.Id, enabledRoute.ID, userID) + err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID) require.NoError(t, err) - peer1DeletedRoute, err := am.GetNetworkMap(peer1ID) + peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") } @@ -1015,13 +1233,13 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createRouterStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(dataDir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) if err != nil { return nil, err } @@ -1036,8 +1254,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er accountID := "testingAcc" domain := "example.com" - account := newAccountWithId(accountID, userID, domain) - err := am.Store.SaveAccount(account) + account := newAccountWithId(context.Background(), accountID, userID, domain) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } @@ -1172,7 +1390,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer5.ID] = peer5 - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } @@ -1180,19 +1398,19 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er if err != nil { return nil, err } - err = am.GroupAddPeer(accountID, groupAll.ID, peer1ID) + err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID) if err != nil { return nil, err } - err = am.GroupAddPeer(accountID, groupAll.ID, peer2ID) + err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID) if err != nil { return nil, err } - err = am.GroupAddPeer(accountID, groupAll.ID, peer3ID) + err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID) if err != nil { return nil, err } - err = am.GroupAddPeer(accountID, groupAll.ID, peer4ID) + err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID) if err != nil { return nil, err } @@ -1231,11 +1449,11 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } for _, group := range newGroup { - err = am.SaveGroup(accountID, userID, group) + err = am.SaveGroup(context.Background(), accountID, userID, group) if err != nil { return nil, err } } - return am.Store.GetAccount(account.Id) + return am.Store.GetAccount(context.Background(), account.Id) } diff --git a/management/server/scheduler.go b/management/server/scheduler.go index 356348056..147b50fc6 100644 --- a/management/server/scheduler.go +++ b/management/server/scheduler.go @@ -1,6 +1,7 @@ package server import ( + "context" "sync" "time" @@ -9,32 +10,32 @@ import ( // Scheduler is an interface which implementations can schedule and cancel jobs type Scheduler interface { - Cancel(IDs []string) - Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) + Cancel(ctx context.Context, IDs []string) + Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) } // MockScheduler is a mock implementation of Scheduler type MockScheduler struct { - CancelFunc func(IDs []string) - ScheduleFunc func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) + CancelFunc func(ctx context.Context, IDs []string) + ScheduleFunc func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) } // Cancel mocks the Cancel function of the Scheduler interface -func (mock *MockScheduler) Cancel(IDs []string) { +func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) { if mock.CancelFunc != nil { - mock.CancelFunc(IDs) + mock.CancelFunc(ctx, IDs) return } - log.Errorf("MockScheduler doesn't have Cancel function defined ") + log.WithContext(ctx).Errorf("MockScheduler doesn't have Cancel function defined ") } // Schedule mocks the Schedule function of the Scheduler interface -func (mock *MockScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { +func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { if mock.ScheduleFunc != nil { - mock.ScheduleFunc(in, ID, job) + mock.ScheduleFunc(ctx, in, ID, job) return } - log.Errorf("MockScheduler doesn't have Schedule function defined") + log.WithContext(ctx).Errorf("MockScheduler doesn't have Schedule function defined") } // DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. @@ -52,35 +53,35 @@ func NewDefaultScheduler() *DefaultScheduler { } } -func (wm *DefaultScheduler) cancel(ID string) bool { +func (wm *DefaultScheduler) cancel(ctx context.Context, ID string) bool { cancel, ok := wm.jobs[ID] if ok { delete(wm.jobs, ID) close(cancel) - log.Debugf("cancelled scheduled job %s", ID) + log.WithContext(ctx).Debugf("cancelled scheduled job %s", ID) } return ok } // Cancel cancels the scheduled job by ID if present. // If job wasn't found the function returns false. -func (wm *DefaultScheduler) Cancel(IDs []string) { +func (wm *DefaultScheduler) Cancel(ctx context.Context, IDs []string) { wm.mu.Lock() defer wm.mu.Unlock() for _, id := range IDs { - wm.cancel(id) + wm.cancel(ctx, id) } } // Schedule a job to run in some time in the future. If job returns true then it will be scheduled one more time. // If job with the provided ID already exists, a new one won't be scheduled. -func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { +func (wm *DefaultScheduler) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wm.mu.Lock() defer wm.mu.Unlock() cancel := make(chan struct{}) if _, ok := wm.jobs[ID]; ok { - log.Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.", + log.WithContext(ctx).Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.", ID, len(wm.jobs)) return } @@ -88,25 +89,25 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne ticker := time.NewTicker(in) wm.jobs[ID] = cancel - log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs)) + log.WithContext(ctx).Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs)) go func() { for { select { case <-ticker.C: select { case <-cancel: - log.Debugf("scheduled job %s was canceled, stop timer", ID) + log.WithContext(ctx).Debugf("scheduled job %s was canceled, stop timer", ID) ticker.Stop() return default: - log.Debugf("time to do a scheduled job %s", ID) + log.WithContext(ctx).Debugf("time to do a scheduled job %s", ID) } runIn, reschedule := job() if !reschedule { wm.mu.Lock() defer wm.mu.Unlock() delete(wm.jobs, ID) - log.Debugf("job %s is not scheduled to run again", ID) + log.WithContext(ctx).Debugf("job %s is not scheduled to run again", ID) ticker.Stop() return } @@ -115,7 +116,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne ticker.Reset(runIn) } case <-cancel: - log.Debugf("job %s was canceled, stopping timer", ID) + log.WithContext(ctx).Debugf("job %s was canceled, stopping timer", ID) ticker.Stop() return } diff --git a/management/server/scheduler_test.go b/management/server/scheduler_test.go index 9dd73e269..7c287a554 100644 --- a/management/server/scheduler_test.go +++ b/management/server/scheduler_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "math/rand" "runtime" @@ -20,7 +21,7 @@ func TestScheduler_Performance(t *testing.T) { minMs := 50 for i := 0; i < n; i++ { millis := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond - go scheduler.Schedule(millis, fmt.Sprintf("test-scheduler-job-%d", i), func() (nextRunIn time.Duration, reschedule bool) { + go scheduler.Schedule(context.Background(), millis, fmt.Sprintf("test-scheduler-job-%d", i), func() (nextRunIn time.Duration, reschedule bool) { time.Sleep(millis) wg.Done() return 0, false @@ -53,19 +54,19 @@ func TestScheduler_Cancel(t *testing.T) { sleepTime = 20 * time.Millisecond } - scheduler.Schedule(scheduletime, jobID1, func() (nextRunIn time.Duration, reschedule bool) { + scheduler.Schedule(context.Background(), scheduletime, jobID1, func() (nextRunIn time.Duration, reschedule bool) { tt := p[0] <-tChan t.Logf("job %s", tt) return scheduletime, true }) - scheduler.Schedule(scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) { + scheduler.Schedule(context.Background(), scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) { return scheduletime, true }) time.Sleep(sleepTime) assert.Len(t, scheduler.jobs, 2) - scheduler.Cancel([]string{jobID1}) + scheduler.Cancel(context.Background(), []string{jobID1}) close(tChan) p = []string{} time.Sleep(sleepTime) @@ -83,7 +84,7 @@ func TestScheduler_Schedule(t *testing.T) { wg.Done() return 0, false } - scheduler.Schedule(300*time.Millisecond, jobID, job) + scheduler.Schedule(context.Background(), 300*time.Millisecond, jobID, job) failed := waitTimeout(wg, time.Second) if failed { t.Fatal("timed out while waiting for test to finish") @@ -107,12 +108,12 @@ func TestScheduler_Schedule(t *testing.T) { return 0, false } - scheduler.Schedule(300*time.Millisecond, jobID, job) + scheduler.Schedule(context.Background(), 300*time.Millisecond, jobID, job) failed = waitTimeout(wg, time.Second) if failed { t.Fatal("timed out while waiting for test to finish") return } - scheduler.cancel(jobID) + scheduler.cancel(context.Background(), jobID) } diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 40b8ac457..dcaee357c 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -1,6 +1,7 @@ package server import ( + "context" "hash/fnv" "strconv" "strings" @@ -207,9 +208,9 @@ func Hash(s string) uint32 { // CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key, // and adds it to the specified account. A list of autoGroups IDs can be empty. -func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, +func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() keyDuration := DefaultSetupKeyDuration @@ -217,7 +218,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string keyDuration = expiresIn } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -230,20 +231,20 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) account.SetupKeys[setupKey.Key] = setupKey - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, status.Errorf(status.Internal, "failed adding account key") } - am.StoreEvent(userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) + am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) for _, g := range setupKey.AutoGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, + am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) } else { - log.Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } @@ -254,15 +255,15 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string // Due to the unique nature of a SetupKey certain properties must not be overwritten // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. -func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -287,12 +288,12 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup account.SetupKeys[newKey.Key] = newKey - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } if !oldKey.Revoked && newKey.Revoked { - am.StoreEvent(userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta()) + am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta()) } defer func() { @@ -301,10 +302,10 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup for _, g := range removedGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, + am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) } else { - log.Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } @@ -312,24 +313,24 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup for _, g := range addedGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, + am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) } else { - log.Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } }() - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return newKey, nil } // ListSetupKeys returns a list of all setup keys of the account -func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -358,11 +359,11 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. -func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 43edabbd6..034f4e2d6 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "strconv" "testing" @@ -20,12 +21,12 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -37,7 +38,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { expiresIn := time.Hour keyName := "my-test-key" - key, err := manager.CreateSetupKey(account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, + key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) @@ -46,7 +47,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { autoGroups := []string{"group_1", "group_2"} newKeyName := "my-new-test-key" revoked := true - newKey, err := manager.SaveSetupKey(account.Id, &SetupKey{ + newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ Id: key.Id, Name: newKeyName, Revoked: revoked, @@ -78,12 +79,12 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -92,7 +93,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -136,7 +137,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, testCase2} { t.Run(tCase.name, func(t *testing.T) { - key, err := manager.CreateSetupKey(account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, + key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { @@ -174,12 +175,12 @@ func TestGetSetupKeys(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -188,7 +189,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 56136327a..b5ae82828 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1,11 +1,14 @@ package server import ( + "context" "encoding/json" "errors" "fmt" + "os" "path/filepath" "runtime" + "runtime/debug" "strings" "sync" "time" @@ -27,6 +30,11 @@ import ( "github.com/netbirdio/netbird/route" ) +const ( + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" +) + // SqlStore represents an account storage backed by a Sql DB persisted to disk type SqlStore struct { db *gorm.DB @@ -45,7 +53,7 @@ type installation struct { type migrationFunc func(*gorm.DB) error // NewSqlStore creates a new SqlStore instance. -func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) { sql, err := db.DB() if err != nil { return nil, err @@ -53,7 +61,7 @@ func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetr conns := runtime.NumCPU() sql.SetMaxOpenConns(conns) // TODO: make it configurable - if err := migrate(db); err != nil { + if err := migrate(ctx, db); err != nil { return nil, fmt.Errorf("migrate: %w", err) } err = db.AutoMigrate( @@ -69,18 +77,18 @@ func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetr } // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock -func (s *SqlStore) AcquireGlobalLock() (unlock func()) { - log.Tracef("acquiring global lock") +func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring global lock") start := time.Now() s.globalAccountLock.Lock() unlock = func() { s.globalAccountLock.Unlock() - log.Tracef("released global lock in %v", time.Since(start)) + log.WithContext(ctx).Tracef("released global lock in %v", time.Since(start)) } took := time.Since(start) - log.Tracef("took %v to acquire global lock", took) + log.WithContext(ctx).Tracef("took %v to acquire global lock", took) if s.metrics != nil { s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) } @@ -88,8 +96,8 @@ func (s *SqlStore) AcquireGlobalLock() (unlock func()) { return unlock } -func (s *SqlStore) AcquireAccountWriteLock(accountID string) (unlock func()) { - log.Tracef("acquiring write lock for account %s", accountID) +func (s *SqlStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring write lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) @@ -98,14 +106,14 @@ func (s *SqlStore) AcquireAccountWriteLock(accountID string) (unlock func()) { unlock = func() { mtx.Unlock() - log.Tracef("released write lock for account %s in %v", accountID, time.Since(start)) + log.WithContext(ctx).Tracef("released write lock for account %s in %v", accountID, time.Since(start)) } return unlock } -func (s *SqlStore) AcquireAccountReadLock(accountID string) (unlock func()) { - log.Tracef("acquiring read lock for account %s", accountID) +func (s *SqlStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring read lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) @@ -114,15 +122,57 @@ func (s *SqlStore) AcquireAccountReadLock(accountID string) (unlock func()) { unlock = func() { mtx.RUnlock() - log.Tracef("released read lock for account %s in %v", accountID, time.Since(start)) + log.WithContext(ctx).Tracef("released read lock for account %s in %v", accountID, time.Since(start)) } return unlock } -func (s *SqlStore) SaveAccount(account *Account) error { +func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { start := time.Now() + // todo: remove this check after the issue is resolved + s.checkAccountDomainBeforeSave(ctx, account.Id, account.Domain) + + generateAccountSQLTypes(account) + + err := s.db.Transaction(func(tx *gorm.DB) error { + result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account) + if result.Error != nil { + return result.Error + } + + result = tx. + Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.OnConflict{UpdateAll: true}). + Create(account) + if result.Error != nil { + return result.Error + } + return nil + }) + + took := time.Since(start) + if s.metrics != nil { + s.metrics.StoreMetrics().CountPersistenceDuration(took) + } + log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds()) + + return err +} + +// generateAccountSQLTypes generates the GORM compatible types for the account +func generateAccountSQLTypes(account *Account) { for _, key := range account.SetupKeys { account.SetupKeysG = append(account.SetupKeysG, *key) } @@ -155,43 +205,25 @@ func (s *SqlStore) SaveAccount(account *Account) error { ns.ID = id account.NameServerGroupsG = append(account.NameServerGroupsG, *ns) } - - err := s.db.Transaction(func(tx *gorm.DB) error { - result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) - if result.Error != nil { - return result.Error - } - - result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) - if result.Error != nil { - return result.Error - } - - result = tx.Select(clause.Associations).Delete(account) - if result.Error != nil { - return result.Error - } - - result = tx. - Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.OnConflict{UpdateAll: true}). - Create(account) - if result.Error != nil { - return result.Error - } - return nil - }) - - took := time.Since(start) - if s.metrics != nil { - s.metrics.StoreMetrics().CountPersistenceDuration(took) - } - log.Debugf("took %d ms to persist an account to the store", took.Milliseconds()) - - return err } -func (s *SqlStore) DeleteAccount(account *Account) error { +// checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank +func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) { + var acc Account + var domain string + result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain) + if result.Error != nil { + if !errors.Is(result.Error, gorm.ErrRecordNotFound) { + log.WithContext(ctx).Errorf("error when getting account %s from the store to check domain: %s", accountID, result.Error) + } + return + } + if domain != "" && newDomain == "" { + log.WithContext(ctx).Warnf("saving an account with empty domain when there was a domain set. Previous domain %s, Account ID: %s, Trace: %s", domain, accountID, debug.Stack()) + } +} + +func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error { start := time.Now() err := s.db.Transaction(func(tx *gorm.DB) error { @@ -217,12 +249,12 @@ func (s *SqlStore) DeleteAccount(account *Account) error { if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) } - log.Debugf("took %d ms to delete an account to the store", took.Milliseconds()) + log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds()) return err } -func (s *SqlStore) SaveInstallationID(ID string) error { +func (s *SqlStore) SaveInstallationID(_ context.Context, ID string) error { installation := installation{InstallationIDValue: ID} installation.ID = uint(s.installationPK) @@ -232,7 +264,7 @@ func (s *SqlStore) SaveInstallationID(ID string) error { func (s *SqlStore) GetInstallationID() string { var installation installation - if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil { + if result := s.db.First(&installation, idQueryCondition, s.installationPK); result.Error != nil { return "" } @@ -289,7 +321,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { return nil } -func (s *SqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) { +func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { var account Account result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", @@ -298,22 +330,22 @@ func (s *SqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } - log.Errorf("error when getting account from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } // TODO: rework to not call GetAccount - return s.GetAccount(account.Id) + return s.GetAccount(ctx, account.Id) } -func (s *SqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) { +func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { var key SetupKey result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting setup key from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting setup key from store") } @@ -321,31 +353,31 @@ func (s *SqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return s.GetAccount(key.AccountID) + return s.GetAccount(ctx, key.AccountID) } -func (s *SqlStore) GetTokenIDByHashedToken(hashedToken string) (string, error) { +func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { var token PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting token from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) return "", status.Errorf(status.Internal, "issue getting account from store") } return token.ID, nil } -func (s *SqlStore) GetUserByTokenID(tokenID string) (*User, error) { +func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { var token PersonalAccessToken - result := s.db.First(&token, "id = ?", tokenID) + result := s.db.First(&token, idQueryCondition, tokenID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting token from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -354,7 +386,7 @@ func (s *SqlStore) GetUserByTokenID(tokenID string) (*User, error) { } var user User - result = s.db.Preload("PATsG").First(&user, "id = ?", token.UserID) + result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID) if result.Error != nil { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } @@ -367,7 +399,7 @@ func (s *SqlStore) GetUserByTokenID(tokenID string) (*User, error) { return &user, nil } -func (s *SqlStore) GetAllAccounts() (all []*Account) { +func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { var accounts []Account result := s.db.Find(&accounts) if result.Error != nil { @@ -375,7 +407,7 @@ func (s *SqlStore) GetAllAccounts() (all []*Account) { } for _, account := range accounts { - if acc, err := s.GetAccount(account.Id); err == nil { + if acc, err := s.GetAccount(ctx, account.Id); err == nil { all = append(all, acc) } } @@ -383,15 +415,15 @@ func (s *SqlStore) GetAllAccounts() (all []*Account) { return all } -func (s *SqlStore) GetAccount(accountID string) (*Account, error) { +func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) { var account Account result := s.db.Model(&account). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload(clause.Associations). - First(&account, "id = ?", accountID) + First(&account, idQueryCondition, accountID) if result.Error != nil { - log.Errorf("error when getting account %s from the store: %s", accountID, result.Error) + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found") } @@ -451,9 +483,9 @@ func (s *SqlStore) GetAccount(accountID string) (*Account, error) { return &account, nil } -func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) { +func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { var user User - result := s.db.Select("account_id").First(&user, "id = ?", userID) + result := s.db.Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -465,17 +497,17 @@ func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return s.GetAccount(user.AccountID) + return s.GetAccount(ctx, user.AccountID) } -func (s *SqlStore) GetAccountByPeerID(peerID string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, "id = ?", peerID) + result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting peer from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -483,10 +515,10 @@ func (s *SqlStore) GetAccountByPeerID(peerID string) (*Account, error) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return s.GetAccount(peer.AccountID) + return s.GetAccount(ctx, peer.AccountID) } -func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) @@ -494,7 +526,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting peer from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -502,10 +534,10 @@ func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return s.GetAccount(peer.AccountID) + return s.GetAccount(ctx, peer.AccountID) } -func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { +func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID) @@ -513,7 +545,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting peer from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return "", status.Errorf(status.Internal, "issue getting account from store") } @@ -523,7 +555,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { var user User var accountID string - result := s.db.Model(&user).Select("account_id").Where("id = ?", userID).First(&accountID) + result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -534,7 +566,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { return accountID, nil } -func (s *SqlStore) GetAccountIDBySetupKey(setupKey string) (string, error) { +func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var key SetupKey var accountID string result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID) @@ -542,34 +574,34 @@ func (s *SqlStore) GetAccountIDBySetupKey(setupKey string) (string, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting setup key from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) return "", status.Errorf(status.Internal, "issue getting setup key from store") } return accountID, nil } -func (s *SqlStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) { +func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) { var peer nbpeer.Peer result := s.db.First(&peer, "key = ?", peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - log.Errorf("error when getting peer from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting peer from store") } return &peer, nil } -func (s *SqlStore) GetAccountSettings(accountID string) (*Settings, error) { +func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) { var accountSettings AccountSettings - if err := s.db.Model(&Account{}).Where("id = ?", accountID).First(&accountSettings).Error; err != nil { + if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - log.Errorf("error when getting settings from the store: %s", err) + log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err) return nil, status.Errorf(status.Internal, "issue getting settings from store") } return accountSettings.Settings, nil @@ -608,7 +640,7 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p } // Close closes the underlying DB connection -func (s *SqlStore) Close() error { +func (s *SqlStore) Close(_ context.Context) error { sql, err := s.db.DB() if err != nil { return fmt.Errorf("get db: %w", err) @@ -622,11 +654,11 @@ func (s *SqlStore) GetStoreEngine() StoreEngine { } // NewSqliteStore creates a new SQLite store. -func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { - storeStr := "store.db?cache=shared" +func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { + storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) if runtime.GOOS == "windows" { // Vo avoid `The process cannot access the file because it is being used by another process` on Windows - storeStr = "store.db" + storeStr = storeSqliteFileName } file := filepath.Join(dataDir, storeStr) @@ -639,11 +671,11 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqlStore, er return nil, err } - return NewSqlStore(db, SqliteStoreEngine, metrics) + return NewSqlStore(ctx, db, SqliteStoreEngine, metrics) } // NewPostgresqlStore creates a new Postgres store. -func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), PrepareStmt: true, @@ -652,23 +684,32 @@ func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*SqlStore, er return nil, err } - return NewSqlStore(db, PostgresStoreEngine, metrics) + return NewSqlStore(ctx, db, PostgresStoreEngine, metrics) +} + +// newPostgresStore initializes a new Postgres store. +func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) { + dsn, ok := os.LookupEnv(postgresDsnEnv) + if !ok { + return nil, fmt.Errorf("%s is not set", postgresDsnEnv) + } + return NewPostgresqlStore(ctx, dsn, metrics) } // NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir. -func NewSqliteStoreFromFileStore(fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewSqliteStore(dataDir, metrics) +func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { + store, err := NewSqliteStore(ctx, dataDir, metrics) if err != nil { return nil, err } - err = store.SaveInstallationID(fileStore.InstallationID) + err = store.SaveInstallationID(ctx, fileStore.InstallationID) if err != nil { return nil, err } - for _, account := range fileStore.GetAllAccounts() { - err := store.SaveAccount(account) + for _, account := range fileStore.GetAllAccounts(ctx) { + err := store.SaveAccount(ctx, account) if err != nil { return nil, err } @@ -678,19 +719,19 @@ func NewSqliteStoreFromFileStore(fileStore *FileStore, dataDir string, metrics t } // NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB. -func NewPostgresqlStoreFromFileStore(fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(dsn, metrics) +func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { + store, err := NewPostgresqlStore(ctx, dsn, metrics) if err != nil { return nil, err } - err = store.SaveInstallationID(fileStore.InstallationID) + err = store.SaveInstallationID(ctx, fileStore.InstallationID) if err != nil { return nil, err } - for _, account := range fileStore.GetAllAccounts() { - err := store.SaveAccount(account) + for _, account := range fileStore.GetAllAccounts(ctx) { + err := store.SaveAccount(ctx, account) if err != nil { return nil, err } diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index fc2743986..e0e893052 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "math/rand" "net" @@ -34,7 +35,7 @@ func TestSqlite_NewStore(t *testing.T) { store := newSqliteStore(t) - if len(store.GetAllAccounts()) != 0 { + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") } } @@ -46,7 +47,7 @@ func TestSqlite_SaveAccount_Large(t *testing.T) { store := newSqliteStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") groupALL, err := account.GetGroupAll() if err != nil { t.Fatal(err) @@ -117,14 +118,14 @@ func TestSqlite_SaveAccount_Large(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 1 { + if len(store.GetAllAccounts(context.Background())) != 1 { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(account.Id) + a, err := store.GetAccount(context.Background(), account.Id) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -191,7 +192,7 @@ func TestSqlite_SaveAccount(t *testing.T) { store := newSqliteStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -203,10 +204,10 @@ func TestSqlite_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) require.NoError(t, err) - account2 := newAccountWithId("account_id2", "testuser2", "") + account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") setupKey = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ @@ -218,14 +219,14 @@ func TestSqlite_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account2) + err = store.SaveAccount(context.Background(), account2) require.NoError(t, err) - if len(store.GetAllAccounts()) != 2 { + if len(store.GetAllAccounts(context.Background())) != 2 { t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(account.Id) + a, err := store.GetAccount(context.Background(), account.Id) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -239,19 +240,19 @@ func TestSqlite_SaveAccount(t *testing.T) { return } - if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil { + if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil { t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountByUser("testuser"); a == nil { + if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil { t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountByPeerID("testpeer"); a == nil { + if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil { t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil { + if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil { t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) } } @@ -270,7 +271,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { Name: "test token", }} - account := newAccountWithId("account_id", testUserID, "") + account := newAccountWithId(context.Background(), "account_id", testUserID, "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -283,33 +284,33 @@ func TestSqlite_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 1 { + if len(store.GetAllAccounts(context.Background())) != 1 { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } - err = store.DeleteAccount(account) + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 0 { + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") } - _, err = store.GetAccountByPeerPubKey("peerkey") + _, err = store.GetAccountByPeerPubKey(context.Background(), "peerkey") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer public key") - _, err = store.GetAccountByUser("testuser") + _, err = store.GetAccountByUser(context.Background(), "testuser") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by user") - _, err = store.GetAccountByPeerID("testpeer") + _, err = store.GetAccountByPeerID(context.Background(), "testpeer") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer id") - _, err = store.GetAccountBySetupKey(setupKey.Key) + _, err = store.GetAccountBySetupKey(context.Background(), setupKey.Key) require.Error(t, err, "expecting error after removing DeleteAccount when getting account by setup key") - _, err = store.GetAccount(account.Id) + _, err = store.GetAccount(context.Background(), account.Id) require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { @@ -339,11 +340,11 @@ func TestSqlite_GetAccount(t *testing.T) { id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - account, err := store.GetAccount(id) + account, err := store.GetAccount(context.Background(), id) require.NoError(t, err) require.Equal(t, id, account.Id, "account id should match") - _, err = store.GetAccount("non-existing-account") + _, err = store.GetAccount(context.Background(), "non-existing-account") assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -357,7 +358,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { store := newSqliteStoreFromFile(t, "testdata/store.json") - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) // save status of non-existing peer @@ -379,13 +380,13 @@ func TestSqlite_SavePeerStatus(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) err = store.SavePeerStatus(account.Id, "testpeer", newStatus) require.NoError(t, err) - account, err = store.GetAccount(account.Id) + account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers["testpeer"].Status @@ -398,7 +399,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { store := newSqliteStoreFromFile(t, "testdata/store.json") - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) peer := &nbpeer.Peer{ @@ -417,7 +418,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { assert.Error(t, err) account.Peers[peer.ID] = peer - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) peer.Location.ConnectionIP = net.ParseIP("35.1.1.1") @@ -428,7 +429,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) assert.NoError(t, err) - account, err = store.GetAccount(account.Id) + account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers[peer.ID].Location @@ -451,11 +452,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { existingDomain := "test.com" - account, err := store.GetAccountByPrivateDomain(existingDomain) + account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) require.NoError(t, err, "should found account") require.Equal(t, existingDomain, account.Domain, "domains should match") - _, err = store.GetAccountByPrivateDomain("missing-domain.com") + _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") require.Error(t, err, "should return error on domain lookup") parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -472,11 +473,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - token, err := store.GetTokenIDByHashedToken(hashed) + token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) require.NoError(t, err) require.Equal(t, id, token) - _, err = store.GetTokenIDByHashedToken("non-existing-hash") + _, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash") require.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -492,11 +493,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - user, err := store.GetUserByTokenID(id) + user, err := store.GetUserByTokenID(context.Background(), id) require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) - _, err = store.GetUserByTokenID("non-existing-id") + _, err = store.GetUserByTokenID(context.Background(), "non-existing-id") require.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -510,7 +511,7 @@ func TestMigrate(t *testing.T) { store := newSqliteStore(t) - err := migrate(store.db) + err := migrate(context.Background(), store.db) require.NoError(t, err, "Migration should not fail on empty db") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") @@ -559,22 +560,43 @@ func TestMigrate(t *testing.T) { rt := &route{ Network: prefix, PeerGroups: []string{"group1", "group2"}, + Route: route2.Route{ID: "route1"}, } err = store.db.Save(rt).Error require.NoError(t, err, "Failed to insert Gob data") - err = migrate(store.db) + err = migrate(context.Background(), store.db) require.NoError(t, err, "Migration should not fail on gob populated db") - err = migrate(store.db) + err = migrate(context.Background(), store.db) require.NoError(t, err, "Migration should not fail on migrated db") + + err = store.db.Delete(rt).Where("id = ?", "route1").Error + require.NoError(t, err, "Failed to delete Gob data") + + prefix = netip.MustParsePrefix("12.0.0.0/24") + nRT := &route2.Route{ + Network: prefix, + ID: "route2", + Peer: "peer-id", + } + + err = store.db.Save(nRT).Error + require.NoError(t, err, "Failed to insert json nil slice data") + + err = migrate(context.Background(), store.db) + require.NoError(t, err, "Migration should not fail on json nil slice populated db") + + err = migrate(context.Background(), store.db) + require.NoError(t, err, "Migration should not fail on migrated db") + } func newSqliteStore(t *testing.T) *SqlStore { t.Helper() - store, err := NewSqliteStore(t.TempDir(), nil) + store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) require.NoError(t, err) require.NotNil(t, store) @@ -589,10 +611,10 @@ func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore { err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) require.NoError(t, err) - fStore, err := NewFileStore(storeDir, nil) + fStore, err := NewFileStore(context.Background(), storeDir, nil) require.NoError(t, err) - store, err := NewSqliteStoreFromFileStore(fStore, storeDir, nil) + store, err := NewSqliteStoreFromFileStore(context.Background(), fStore, storeDir, nil) require.NoError(t, err) require.NotNil(t, store) @@ -601,7 +623,7 @@ func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore { func newAccount(store Store, id int) error { str := fmt.Sprintf("%s-%d", uuid.New().String(), id) - account := newAccountWithId(str, str+"-testuser", "example.com") + account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ @@ -613,7 +635,7 @@ func newAccount(store Store, id int) error { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - return store.SaveAccount(account) + return store.SaveAccount(context.Background(), account) } func newPostgresqlStore(t *testing.T) *SqlStore { @@ -630,7 +652,7 @@ func newPostgresqlStore(t *testing.T) *SqlStore { t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) } - store, err := NewPostgresqlStore(postgresDsn, nil) + store, err := NewPostgresqlStore(context.Background(), postgresDsn, nil) if err != nil { t.Fatalf("could not initialize postgresql store: %s", err) } @@ -647,7 +669,7 @@ func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore { err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) require.NoError(t, err) - fStore, err := NewFileStore(storeDir, nil) + fStore, err := NewFileStore(context.Background(), storeDir, nil) require.NoError(t, err) cleanUp, err := testutil.CreatePGDB() @@ -661,7 +683,7 @@ func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore { t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) } - store, err := NewPostgresqlStoreFromFileStore(fStore, postgresDsn, nil) + store, err := NewPostgresqlStoreFromFileStore(context.Background(), fStore, postgresDsn, nil) require.NoError(t, err) require.NotNil(t, store) @@ -675,7 +697,7 @@ func TestPostgresql_NewStore(t *testing.T) { store := newPostgresqlStore(t) - if len(store.GetAllAccounts()) != 0 { + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") } } @@ -687,7 +709,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { store := newPostgresqlStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -699,10 +721,10 @@ func TestPostgresql_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) require.NoError(t, err) - account2 := newAccountWithId("account_id2", "testuser2", "") + account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") setupKey = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ @@ -714,14 +736,14 @@ func TestPostgresql_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account2) + err = store.SaveAccount(context.Background(), account2) require.NoError(t, err) - if len(store.GetAllAccounts()) != 2 { + if len(store.GetAllAccounts(context.Background())) != 2 { t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(account.Id) + a, err := store.GetAccount(context.Background(), account.Id) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -735,19 +757,19 @@ func TestPostgresql_SaveAccount(t *testing.T) { return } - if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil { + if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil { t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountByUser("testuser"); a == nil { + if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil { t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountByPeerID("testpeer"); a == nil { + if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil { t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil { + if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil { t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) } } @@ -766,7 +788,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { Name: "test token", }} - account := newAccountWithId("account_id", testUserID, "") + account := newAccountWithId(context.Background(), "account_id", testUserID, "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -779,33 +801,33 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 1 { + if len(store.GetAllAccounts(context.Background())) != 1 { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } - err = store.DeleteAccount(account) + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 0 { + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") } - _, err = store.GetAccountByPeerPubKey("peerkey") + _, err = store.GetAccountByPeerPubKey(context.Background(), "peerkey") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer public key") - _, err = store.GetAccountByUser("testuser") + _, err = store.GetAccountByUser(context.Background(), "testuser") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by user") - _, err = store.GetAccountByPeerID("testpeer") + _, err = store.GetAccountByPeerID(context.Background(), "testpeer") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer id") - _, err = store.GetAccountBySetupKey(setupKey.Key) + _, err = store.GetAccountBySetupKey(context.Background(), setupKey.Key) require.Error(t, err, "expecting error after removing DeleteAccount when getting account by setup key") - _, err = store.GetAccount(account.Id) + _, err = store.GetAccount(context.Background(), account.Id) require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { @@ -833,7 +855,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { store := newPostgresqlStoreFromFile(t, "testdata/store.json") - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) // save status of non-existing peer @@ -852,13 +874,13 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) err = store.SavePeerStatus(account.Id, "testpeer", newStatus) require.NoError(t, err) - account, err = store.GetAccount(account.Id) + account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers["testpeer"].Status @@ -874,11 +896,11 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { existingDomain := "test.com" - account, err := store.GetAccountByPrivateDomain(existingDomain) + account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) require.NoError(t, err, "should found account") require.Equal(t, existingDomain, account.Domain, "domains should match") - _, err = store.GetAccountByPrivateDomain("missing-domain.com") + _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") require.Error(t, err, "should return error on domain lookup") } @@ -892,7 +914,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - token, err := store.GetTokenIDByHashedToken(hashed) + token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) require.NoError(t, err) require.Equal(t, id, token) } @@ -906,7 +928,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - user, err := store.GetUserByTokenID(id) + user, err := store.GetUserByTokenID(context.Background(), id) require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) } diff --git a/management/server/store.go b/management/server/store.go index 5210f1210..05a09b3ee 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -1,10 +1,13 @@ package server import ( + "context" + "errors" "fmt" "net" "net/netip" "os" + "path" "path/filepath" "strings" "time" @@ -12,50 +15,52 @@ import ( log "github.com/sirupsen/logrus" "gorm.io/gorm" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/management/server/migration" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/route" ) type Store interface { - GetAllAccounts() []*Account - GetAccount(accountID string) (*Account, error) - DeleteAccount(account *Account) error - GetAccountByUser(userID string) (*Account, error) - GetAccountByPeerPubKey(peerKey string) (*Account, error) - GetAccountIDByPeerPubKey(peerKey string) (string, error) + GetAllAccounts(ctx context.Context) []*Account + GetAccount(ctx context.Context, accountID string) (*Account, error) + DeleteAccount(ctx context.Context, account *Account) error + GetAccountByUser(ctx context.Context, userID string) (*Account, error) + GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) + GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(peerKey string) (string, error) - GetAccountIDBySetupKey(peerKey string) (string, error) - GetAccountByPeerID(peerID string) (*Account, error) - GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later - GetAccountByPrivateDomain(domain string) (*Account, error) - GetTokenIDByHashedToken(secret string) (string, error) - GetUserByTokenID(tokenID string) (*User, error) + GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) + GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later + GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) + GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) + GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) - SaveAccount(account *Account) error + SaveAccount(ctx context.Context, account *Account) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error GetInstallationID() string - SaveInstallationID(ID string) error + SaveInstallationID(ctx context.Context, ID string) error // AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock - AcquireAccountWriteLock(accountID string) func() + AcquireAccountWriteLock(ctx context.Context, accountID string) func() // AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock - AcquireAccountReadLock(accountID string) func() + AcquireAccountReadLock(ctx context.Context, accountID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock - AcquireGlobalLock() func() + AcquireGlobalLock(ctx context.Context) func() SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error // Close should close the store persisting all unsaved data. - Close() error + Close(ctx context.Context) error // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine - GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) - GetAccountSettings(accountID string) (*Settings, error) + GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) + GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) } type StoreEngine string @@ -76,53 +81,76 @@ func getStoreEngineFromEnv() StoreEngine { } value := StoreEngine(strings.ToLower(kind)) - if value == FileStoreEngine || value == SqliteStoreEngine || value == PostgresStoreEngine { + if value == SqliteStoreEngine || value == PostgresStoreEngine { return value } return SqliteStoreEngine } -func getStoreEngineFromDatadir(dataDir string) StoreEngine { - storeFile := filepath.Join(dataDir, storeFileName) - if _, err := os.Stat(storeFile); err != nil { - // json file not found then use sqlite as default - return SqliteStoreEngine - } - return FileStoreEngine -} - -func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { +// getStoreEngine determines the store engine to use. +// If no engine is specified, it attempts to retrieve it from the environment. +// If still not specified, it defaults to using SQLite. +// Additionally, it handles the migration from a JSON store file to SQLite if applicable. +func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) StoreEngine { if kind == "" { - // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE kind = getStoreEngineFromEnv() if kind == "" { - // NETBIRD_STORE_ENGINE is not set we evaluate default based on dataDir - kind = getStoreEngineFromDatadir(dataDir) + kind = SqliteStoreEngine + + // Migrate if it is the first run with a JSON file existing and no SQLite file present + jsonStoreFile := filepath.Join(dataDir, storeFileName) + sqliteStoreFile := filepath.Join(dataDir, storeSqliteFileName) + + if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) { + log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile) + + // Attempt to migrate from JSON store to SQLite + if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil { + log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err) + kind = FileStoreEngine + } + } } } + + return kind +} + +// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics +func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { + kind = getStoreEngine(ctx, dataDir, kind) + + if err := checkFileStoreEngine(kind, dataDir); err != nil { + return nil, err + } + switch kind { - case FileStoreEngine: - log.Info("using JSON file store engine") - return NewFileStore(dataDir, metrics) case SqliteStoreEngine: - log.Info("using SQLite store engine") - return NewSqliteStore(dataDir, metrics) + log.WithContext(ctx).Info("using SQLite store engine") + return NewSqliteStore(ctx, dataDir, metrics) case PostgresStoreEngine: - log.Info("using Postgres store engine") - dsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - return nil, fmt.Errorf("%s is not set", postgresDsnEnv) - } - return NewPostgresqlStore(dsn, metrics) + log.WithContext(ctx).Info("using Postgres store engine") + return newPostgresStore(ctx, metrics) default: - return nil, fmt.Errorf("unsupported kind of store %s", kind) + return nil, fmt.Errorf("unsupported kind of store: %s", kind) } } +func checkFileStoreEngine(kind StoreEngine, dataDir string) error { + if kind == FileStoreEngine { + storeFile := filepath.Join(dataDir, storeFileName) + if util.FileExists(storeFile) { + return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+ + "https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", FileStoreEngine) + } + } + return nil +} + // migrate migrates the SQLite database to the latest schema -func migrate(db *gorm.DB) error { - migrations := getMigrations() +func migrate(ctx context.Context, db *gorm.DB) error { + migrations := getMigrations(ctx) for _, m := range migrations { if err := m(db); err != nil { @@ -133,52 +161,45 @@ func migrate(db *gorm.DB) error { return nil } -func getMigrations() []migrationFunc { +func getMigrations(ctx context.Context) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net") + return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](ctx, db, "network_net") }, func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network") + return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network") }, func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups") + return migration.MigrateFieldFromGobToJSON[route.Route, []string](ctx, db, "peer_groups") }, func(db *gorm.DB) error { - return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "") + return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "location_connection_ip", "") }, func(db *gorm.DB) error { - return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip") + return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") }, } } // NewTestStoreFromJson is only used in tests -func NewTestStoreFromJson(dataDir string) (Store, func(), error) { - fstore, err := NewFileStore(dataDir, nil) +func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), error) { + fstore, err := NewFileStore(ctx, dataDir, nil) if err != nil { return nil, nil, err } - cleanUp := func() {} - // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE kind := getStoreEngineFromEnv() if kind == "" { - // NETBIRD_STORE_ENGINE is not set we evaluate default based on dataDir - kind = getStoreEngineFromDatadir(dataDir) + kind = SqliteStoreEngine } - switch kind { - case FileStoreEngine: - return fstore, cleanUp, nil - case SqliteStoreEngine: - store, err := NewSqliteStoreFromFileStore(fstore, dataDir, nil) - if err != nil { - return nil, nil, err - } - return store, cleanUp, nil - case PostgresStoreEngine: + var ( + store Store + cleanUp func() + ) + + if kind == PostgresStoreEngine { cleanUp, err = testutil.CreatePGDB() if err != nil { return nil, nil, err @@ -189,16 +210,52 @@ func NewTestStoreFromJson(dataDir string) (Store, func(), error) { return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - store, err := NewPostgresqlStoreFromFileStore(fstore, dsn, nil) + store, err = NewPostgresqlStoreFromFileStore(ctx, fstore, dsn, nil) if err != nil { return nil, nil, err } - return store, cleanUp, nil - default: - store, err := NewSqliteStoreFromFileStore(fstore, dataDir, nil) + } else { + store, err = NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil) if err != nil { return nil, nil, err } - return store, cleanUp, nil + cleanUp = func() { store.Close(ctx) } } + + return store, cleanUp, nil +} + +// MigrateFileStoreToSqlite migrates the file store to the SQLite store. +func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error { + fileStorePath := path.Join(dataDir, storeFileName) + if _, err := os.Stat(fileStorePath); errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("%s doesn't exist, couldn't continue the operation", fileStorePath) + } + + sqlStorePath := path.Join(dataDir, storeSqliteFileName) + if _, err := os.Stat(sqlStorePath); err == nil { + return fmt.Errorf("%s already exists, couldn't continue the operation", sqlStorePath) + } + + fstore, err := NewFileStore(ctx, dataDir, nil) + if err != nil { + return fmt.Errorf("failed creating file store: %s: %v", dataDir, err) + } + + fsStoreAccounts := len(fstore.GetAllAccounts(ctx)) + log.WithContext(ctx).Infof("%d account will be migrated from file store %s to sqlite store %s", + fsStoreAccounts, fileStorePath, sqlStorePath) + + store, err := NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil) + if err != nil { + return fmt.Errorf("failed creating file store: %s: %v", dataDir, err) + } + + sqliteStoreAccounts := len(store.GetAllAccounts(ctx)) + if fsStoreAccounts != sqliteStoreAccounts { + return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d", + fsStoreAccounts, sqliteStoreAccounts) + } + + return nil } diff --git a/management/server/store_test.go b/management/server/store_test.go index 3f8c5d18b..40c36c9e0 100644 --- a/management/server/store_test.go +++ b/management/server/store_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "testing" @@ -15,13 +16,13 @@ type benchCase struct { var newFs = func(b *testing.B) Store { b.Helper() - store, _ := NewFileStore(b.TempDir(), nil) + store, _ := NewFileStore(context.Background(), b.TempDir(), nil) return store } var newSqlite = func(b *testing.B) Store { b.Helper() - store, _ := NewSqliteStore(b.TempDir(), nil) + store, _ := NewSqliteStore(context.Background(), b.TempDir(), nil) return store } @@ -76,13 +77,13 @@ func BenchmarkTest_StoreRead(b *testing.B) { _ = newAccount(store, i) } - accounts := store.GetAllAccounts() + accounts := store.GetAllAccounts(context.Background()) id := accounts[c.size-1].Id b.Run(name, func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, _ = store.GetAccount(id) + _, _ = store.GetAccount(context.Background(), id) } }) }) diff --git a/management/server/telemetry/app_metrics.go b/management/server/telemetry/app_metrics.go index 56f4fb9c8..d88e18d8a 100644 --- a/management/server/telemetry/app_metrics.go +++ b/management/server/telemetry/app_metrics.go @@ -22,7 +22,7 @@ const defaultEndpoint = "/metrics" type MockAppMetrics struct { GetMeterFunc func() metric2.Meter CloseFunc func() error - ExposeFunc func(port int, endpoint string) error + ExposeFunc func(ctx context.Context, port int, endpoint string) error IDPMetricsFunc func() *IDPMetrics HTTPMiddlewareFunc func() *HTTPMiddleware GRPCMetricsFunc func() *GRPCMetrics @@ -47,9 +47,9 @@ func (mock *MockAppMetrics) Close() error { } // Expose mocks the Expose function of the AppMetrics interface -func (mock *MockAppMetrics) Expose(port int, endpoint string) error { +func (mock *MockAppMetrics) Expose(ctx context.Context, port int, endpoint string) error { if mock.ExposeFunc != nil { - return mock.ExposeFunc(port, endpoint) + return mock.ExposeFunc(ctx, port, endpoint) } return fmt.Errorf("unimplemented") } @@ -98,7 +98,7 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics { type AppMetrics interface { GetMeter() metric2.Meter Close() error - Expose(port int, endpoint string) error + Expose(ctx context.Context, port int, endpoint string) error IDPMetrics() *IDPMetrics HTTPMiddleware() *HTTPMiddleware GRPCMetrics() *GRPCMetrics @@ -154,7 +154,7 @@ func (appMetrics *defaultAppMetrics) Close() error { // Expose metrics on a given port and endpoint. If endpoint is empty a defaultEndpoint one will be used. // Exposes metrics in the Prometheus format https://prometheus.io/ -func (appMetrics *defaultAppMetrics) Expose(port int, endpoint string) error { +func (appMetrics *defaultAppMetrics) Expose(ctx context.Context, port int, endpoint string) error { if endpoint == "" { endpoint = defaultEndpoint } @@ -174,7 +174,7 @@ func (appMetrics *defaultAppMetrics) Expose(port int, endpoint string) error { } }() - log.Infof("enabled application metrics and exposing on http://%s", listener.Addr().String()) + log.WithContext(ctx).Infof("enabled application metrics and exposing on http://%s", listener.Addr().String()) return nil } diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index c29533661..a80453dca 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -3,14 +3,17 @@ package telemetry import ( "context" "fmt" - "hash/fnv" "net/http" "strings" "time" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" + + "github.com/netbirdio/netbird/formatter" + nbContext "github.com/netbirdio/netbird/management/server/context" ) const ( @@ -163,8 +166,15 @@ func getResponseCounterKey(endpoint, method string, status int) string { func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { fn := func(rw http.ResponseWriter, r *http.Request) { reqStart := time.Now() - traceID := hash(fmt.Sprintf("%v", r)) - log.Tracef("HTTP request %v: %v %v", traceID, r.Method, r.URL) + + //nolint + ctx := context.WithValue(r.Context(), formatter.ExecutionContextKey, formatter.HTTPSource) + + reqID := uuid.New().String() + //nolint + ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID) + + log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL) metricKey := getRequestCounterKey(r.URL.Path, r.Method) @@ -175,12 +185,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { w := WrapResponseWriter(rw) - h.ServeHTTP(w, r) + h.ServeHTTP(w, r.WithContext(ctx)) if w.Status() > 399 { - log.Errorf("HTTP response %v: %v %v status %v", traceID, r.Method, r.URL, w.Status()) + log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) } else { - log.Tracef("HTTP response %v: %v %v status %v", traceID, r.Method, r.URL, w.Status()) + log.WithContext(ctx).Tracef("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) } metricKey = getResponseCounterKey(r.URL.Path, r.Method, w.Status()) @@ -198,7 +208,7 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { if c, ok := m.httpRequestDurations[durationKey]; ok { c.Record(m.ctx, reqTook.Milliseconds()) } - log.Debugf("request %s %s took %d ms and finished with status %d", r.Method, r.URL.Path, reqTook.Milliseconds(), w.Status()) + log.WithContext(ctx).Debugf("request %s %s took %d ms and finished with status %d", r.Method, r.URL.Path, reqTook.Milliseconds(), w.Status()) if w.Status() == 200 && (r.Method == http.MethodPut || r.Method == http.MethodPost || r.Method == http.MethodDelete) { opts := metric.WithAttributeSet(attribute.NewSet(attribute.String("type", "write"))) @@ -212,12 +222,3 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(fn) } - -func hash(s string) uint32 { - h := fnv.New32a() - _, err := h.Write([]byte(s)) - if err != nil { - panic(err) - } - return h.Sum32() -} diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index 8db95bd2c..156a762fb 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -33,7 +33,7 @@ func CreatePGDB() (func(), error) { timeout := 10 * time.Second err = c.Stop(ctx, &timeout) if err != nil { - log.Warnf("failed to stop container: %s", err) + log.WithContext(ctx).Warnf("failed to stop container: %s", err) } } diff --git a/management/server/token_mgr.go b/management/server/token_mgr.go index 3f30d0494..f5003004b 100644 --- a/management/server/token_mgr.go +++ b/management/server/token_mgr.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "sync" "time" @@ -14,7 +15,7 @@ import ( // TURNRelayTokenManager used to manage TURN credentials type TURNRelayTokenManager interface { Generate() (*TURNRelayToken, error) - SetupRefresh(peerKey string) + SetupRefresh(ctx context.Context, peerKey string) CancelRefresh(peerKey string) } @@ -67,13 +68,13 @@ func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) { // SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer. // A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone. -func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) { +func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) { m.mux.Lock() defer m.mux.Unlock() m.cancel(peerID) cancel := make(chan struct{}, 1) m.cancelMap[peerID] = cancel - log.Debugf("starting turn refresh for %s", peerID) + log.WithContext(ctx).Debugf("starting turn refresh for %s", peerID) go func() { // we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL) @@ -83,16 +84,16 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) { for { select { case <-cancel: - log.Debugf("stopping turn refresh for %s", peerID) + log.WithContext(ctx).Debugf("stopping turn refresh for %s", peerID) return case <-ticker.C: - m.pushNewTokens(peerID) + m.pushNewTokens(ctx, peerID) } } }() } -func (m *TimeBasedAuthSecretsManager) pushNewTokens(peerID string) { +func (m *TimeBasedAuthSecretsManager) pushNewTokens(ctx context.Context, peerID string) { token, err := m.hmacToken.GenerateToken() if err != nil { log.Errorf("failed to generate token for peer '%s': %s", peerID, err) @@ -121,6 +122,6 @@ func (m *TimeBasedAuthSecretsManager) pushNewTokens(peerID string) { }, }, } - log.Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update}) + log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) + m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) } diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go index 70314f4fc..6ac3571bb 100644 --- a/management/server/token_mgr_test.go +++ b/management/server/token_mgr_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/hmac" "crypto/sha1" "encoding/base64" @@ -46,7 +47,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { secret := "some_secret" peersManager := NewPeersUpdateManager(nil) peer := "some_peer" - updateChannel := peersManager.CreateChannel(peer) + updateChannel := peersManager.CreateChannel(context.Background(), peer) tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ CredentialsTTL: ttl, @@ -54,7 +55,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { Turns: []*Host{TurnTestHost}, }, "") - tested.SetupRefresh(peer) + tested.SetupRefresh(context.Background(), peer) if _, ok := tested.cancelMap[peer]; !ok { t.Errorf("expecting peer to be present in a cancel map, got not present") @@ -102,7 +103,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { Turns: []*Host{TurnTestHost}, }, "") - tested.SetupRefresh(peer) + tested.SetupRefresh(context.Background(), peer) if _, ok := tested.cancelMap[peer]; !ok { t.Errorf("expecting peer to be present in a cancel map, got not present") } diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index f760c5a75..c11225dbc 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -1,6 +1,7 @@ package server import ( + "context" "sync" "time" @@ -35,7 +36,7 @@ func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { } // SendUpdate sends update message to the peer's channel -func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) { +func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) { start := time.Now() var found, dropped bool @@ -51,18 +52,18 @@ func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) { found = true select { case channel <- update: - log.Debugf("update was sent to channel for peer %s", peerID) + log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID) default: dropped = true - log.Warnf("channel for peer %s is %d full", peerID, len(channel)) + log.WithContext(ctx).Warnf("channel for peer %s is %d full", peerID, len(channel)) } } else { - log.Debugf("peer %s has no channel", peerID) + log.WithContext(ctx).Debugf("peer %s has no channel", peerID) } } // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. -func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage { +func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { start := time.Now() closed := false @@ -84,22 +85,22 @@ func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage { channel := make(chan *UpdateMessage, channelBufferSize) p.peerChannels[peerID] = channel - log.Debugf("opened updates channel for a peer %s", peerID) + log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) return channel } -func (p *PeersUpdateManager) closeChannel(peerID string) { +func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) } - log.Debugf("closed updates channel of a peer %s", peerID) + log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) } // CloseChannels closes updates channel for each given peer -func (p *PeersUpdateManager) CloseChannels(peerIDs []string) { +func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string) { start := time.Now() p.channelsMux.Lock() @@ -111,12 +112,12 @@ func (p *PeersUpdateManager) CloseChannels(peerIDs []string) { }() for _, id := range peerIDs { - p.closeChannel(id) + p.closeChannel(ctx, id) } } // CloseChannel closes updates channel of a given peer -func (p *PeersUpdateManager) CloseChannel(peerID string) { +func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) { start := time.Now() p.channelsMux.Lock() @@ -127,7 +128,7 @@ func (p *PeersUpdateManager) CloseChannel(peerID string) { } }() - p.closeChannel(peerID) + p.closeChannel(ctx, peerID) } // GetAllConnectedPeers returns a copy of the connected peers map diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 187e404c5..69f5b895c 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -1,20 +1,21 @@ package server import ( + "context" "testing" "time" "github.com/netbirdio/netbird/management/proto" ) -//var peersUpdater *PeersUpdateManager +// var peersUpdater *PeersUpdateManager func TestCreateChannel(t *testing.T) { peer := "test-create" peersUpdater := NewPeersUpdateManager(nil) - defer peersUpdater.CloseChannel(peer) + defer peersUpdater.CloseChannel(context.Background(), peer) - _ = peersUpdater.CreateChannel(peer) + _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") } @@ -28,11 +29,11 @@ func TestSendUpdate(t *testing.T) { Serial: 0, }, }} - _ = peersUpdater.CreateChannel(peer) + _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") } - peersUpdater.SendUpdate(peer, update1) + peersUpdater.SendUpdate(context.Background(), peer, update1) select { case <-peersUpdater.peerChannels[peer]: default: @@ -40,7 +41,7 @@ func TestSendUpdate(t *testing.T) { } for range [channelBufferSize]int{} { - peersUpdater.SendUpdate(peer, update1) + peersUpdater.SendUpdate(context.Background(), peer, update1) } update2 := &UpdateMessage{Update: &proto.SyncResponse{ @@ -49,7 +50,7 @@ func TestSendUpdate(t *testing.T) { }, }} - peersUpdater.SendUpdate(peer, update2) + peersUpdater.SendUpdate(context.Background(), peer, update2) timeout := time.After(5 * time.Second) for range [channelBufferSize]int{} { select { @@ -67,11 +68,11 @@ func TestSendUpdate(t *testing.T) { func TestCloseChannel(t *testing.T) { peer := "test-close" peersUpdater := NewPeersUpdateManager(nil) - _ = peersUpdater.CreateChannel(peer) + _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") } - peersUpdater.CloseChannel(peer) + peersUpdater.CloseChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; ok { t.Error("Error closing the channel") } diff --git a/management/server/user.go b/management/server/user.go index 2be73fa07..302cfccaa 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "strings" "time" @@ -209,11 +210,11 @@ func NewOwnerUser(id string) *User { } // createServiceUser creates a new service user under the given account. -func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } @@ -232,16 +233,16 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs newUserID := uuid.New().String() newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI) - log.Debugf("New User: %v", newUser) + log.WithContext(ctx).Debugf("New User: %v", newUser) account.Users[newUserID] = newUser - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } meta := map[string]any{"name": newUser.ServiceUserName} - am.StoreEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) + am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) return &UserInfo{ ID: newUser.Id, @@ -257,16 +258,16 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs } // CreateUser creates a new user under the given account. Effectively this is a user invite. -func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *UserInfo) (*UserInfo, error) { if user.IsServiceUser { - return am.createServiceUser(accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) + return am.createServiceUser(ctx, accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) } - return am.inviteNewUser(accountID, userID, user) + return am.inviteNewUser(ctx, accountID, userID, user) } // inviteNewUser Invites a USer to a given account and creates reference in datastore -func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if am.idpManager == nil { @@ -289,7 +290,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite default: } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } @@ -305,13 +306,13 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite } // inviterUser is the one who is inviting the new user - inviterUser, err := am.lookupUserInCache(inviterID, account) + inviterUser, err := am.lookupUserInCache(ctx, inviterID, account) if err != nil || inviterUser == nil { return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID) } // check if the user is already registered with this email => reject - user, err := am.lookupUserInCacheByEmail(invite.Email, accountID) + user, err := am.lookupUserInCacheByEmail(ctx, invite.Email, accountID) if err != nil { return nil, err } @@ -320,7 +321,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") } - users, err := am.idpManager.GetUserByEmail(invite.Email) + users, err := am.idpManager.GetUserByEmail(ctx, invite.Email) if err != nil { return nil, err } @@ -329,7 +330,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") } - idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID, inviterUser.Email) + idpUser, err := am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email) if err != nil { return nil, err } @@ -344,33 +345,33 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite } account.Users[idpUser.ID] = newUser - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } - _, err = am.refreshCache(account.Id) + _, err = am.refreshCache(ctx, account.Id) if err != nil { return nil, err } - am.StoreEvent(userID, newUser.Id, accountID, activity.UserInvited, nil) + am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil) return newUser.ToUserInfo(idpUser, account.Settings) } // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. -func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) { - account, _, err := am.GetAccountFromToken(claims) +func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { + account, _, err := am.GetAccountFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - unlock := am.Store.AcquireAccountWriteLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) defer unlock() - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(ctx, account.Id) if err != nil { return nil, fmt.Errorf("failed to get an account from store %v", err) } @@ -386,12 +387,12 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin) if err != nil { - log.Errorf("failed saving user last login: %v", err) + log.WithContext(ctx).Errorf("failed saving user last login: %v", err) } if newLogin { meta := map[string]any{"timestamp": claims.LastLogin} - am.StoreEvent(claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) + am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) } return user, nil @@ -399,11 +400,11 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. -func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -416,21 +417,21 @@ func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) { return users, nil } -func (am *DefaultAccountManager) deleteServiceUser(account *Account, initiatorUserID string, targetUser *User) { +func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) { meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} - am.StoreEvent(initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) + am.StoreEvent(ctx, initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) delete(account.Users, targetUser.Id) } // DeleteUser deletes a user from the given account. -func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error { +func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error { if initiatorUserID == targetUserID { return status.Errorf(status.InvalidArgument, "self deletion is not allowed") } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -463,43 +464,43 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t return status.Errorf(status.PermissionDenied, "service user is marked as non-deletable") } - am.deleteServiceUser(account, initiatorUserID, targetUser) - return am.Store.SaveAccount(account) + am.deleteServiceUser(ctx, account, initiatorUserID, targetUser) + return am.Store.SaveAccount(ctx, account) } - return am.deleteRegularUser(account, initiatorUserID, targetUserID) + return am.deleteRegularUser(ctx, account, initiatorUserID, targetUserID) } -func (am *DefaultAccountManager) deleteRegularUser(account *Account, initiatorUserID, targetUserID string) error { - tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(account.Id, initiatorUserID, targetUserID) +func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { + tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) if err != nil { - log.Errorf("failed to resolve email address: %s", err) + log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) return err } if !isNil(am.idpManager) { // Delete if the user already exists in the IdP.Necessary in cases where a user account // was created where a user account was provisioned but the user did not sign in - _, err = am.idpManager.GetUserDataByID(targetUserID, idp.AppMetadata{WTAccountID: account.Id}) + _, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id}) if err == nil { - err = am.deleteUserFromIDP(targetUserID, account.Id) + err = am.deleteUserFromIDP(ctx, targetUserID, account.Id) if err != nil { - log.Debugf("failed to delete user from IDP: %s", targetUserID) + log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID) return err } } else { - log.Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) + log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) } } - err = am.deleteUserPeers(initiatorUserID, targetUserID, account) + err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) if err != nil { return err } u, err := account.FindUser(targetUserID) if err != nil { - log.Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err) + log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err) } var tuCreatedAt time.Time @@ -508,20 +509,20 @@ func (am *DefaultAccountManager) deleteRegularUser(account *Account, initiatorUs } delete(account.Users, targetUserID) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } meta := map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt} - am.StoreEvent(initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) + am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } -func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetUserID string, account *Account) error { +func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error { peers, err := account.FindUserPeers(targetUserID) if err != nil { return status.Errorf(status.Internal, "failed to find user peers") @@ -532,25 +533,25 @@ func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetU peerIDs = append(peerIDs, peer.ID) } - return am.deletePeers(account, peerIDs, initiatorUserID) + return am.deletePeers(ctx, account, peerIDs, initiatorUserID) } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. -func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if am.idpManager == nil { return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } // check if the user is already registered with this ID - user, err := am.lookupUserInCache(targetUserID, account) + user, err := am.lookupUserInCache(ctx, targetUserID, account) if err != nil { return err } @@ -565,19 +566,19 @@ func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID st return status.Errorf(status.PreconditionFailed, "can't invite a user with an activated NetBird account") } - err = am.idpManager.InviteUserByID(user.ID) + err = am.idpManager.InviteUserByID(ctx, user.ID) if err != nil { return err } - am.StoreEvent(initiatorUserID, user.ID, accountID, activity.UserInvited, nil) + am.StoreEvent(ctx, initiatorUserID, user.ID, accountID, activity.UserInvited, nil) return nil } // CreatePAT creates a new PAT for the given user -func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if tokenName == "" { @@ -588,7 +589,7 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -614,23 +615,23 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str targetUser.PATs[pat.ID] = &pat.PersonalAccessToken - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, status.Errorf(status.Internal, "failed to save account: %v", err) } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} - am.StoreEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta) + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta) return pat, nil } // DeletePAT deletes a specific PAT from a user -func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return status.Errorf(status.NotFound, "account not found: %s", err) } @@ -664,11 +665,11 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} - am.StoreEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) delete(targetUser.PATs, tokenID) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return status.Errorf(status.Internal, "Failed to save account: %s", err) } @@ -676,11 +677,11 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str } // GetPAT returns a specific PAT from a user -func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, status.Errorf(status.NotFound, "account not found: %s", err) } @@ -708,11 +709,11 @@ func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string } // GetAllPATs returns all PATs for a user -func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, status.Errorf(status.NotFound, "account not found: %s", err) } @@ -740,21 +741,21 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID st } // SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error. -func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) { - return am.SaveOrAddUser(accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound +func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) { + return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound } // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. -func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if update == nil { return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -834,8 +835,8 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string return nil, err } - if err := am.expireAndUpdatePeers(account, blockedPeers); err != nil { - log.Errorf("failed update expired peers: %s", err) + if err := am.expireAndUpdatePeers(ctx, account, blockedPeers); err != nil { + log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } } @@ -847,13 +848,13 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } else { - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } } @@ -861,17 +862,17 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string defer func() { if oldUser.IsBlocked() != update.IsBlocked() { if update.IsBlocked() { - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) } else { - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) } } switch { case transferedOwnerRole: - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) case oldUser.Role != newUser.Role: - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) default: } @@ -881,17 +882,17 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string for _, g := range removedGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) } else { - log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) } } for _, g := range addedGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) } } @@ -899,7 +900,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string }() if !isNil(am.idpManager) && !newUser.IsServiceUser { - userData, err := am.lookupUserInCache(newUser.Id, account) + userData, err := am.lookupUserInCache(ctx, newUser.Id, account) if err != nil { return nil, err } @@ -909,22 +910,22 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) { start := time.Now() - unlock := am.Store.AcquireGlobalLock() + unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() - log.Debugf("Acquired global lock in %s for user %s", time.Since(start), userID) + log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), userID) lowerDomain := strings.ToLower(domain) - account, err := am.Store.GetAccountByUser(userID) + account, err := am.Store.GetAccountByUser(ctx, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - account, err = am.newAccount(userID, lowerDomain) + account, err = am.newAccount(ctx, userID, lowerDomain) if err != nil { return nil, err } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } @@ -938,7 +939,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) if account.Domain != lowerDomain && userObj.Role == UserRoleOwner { account.Domain = lowerDomain - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, status.Errorf(status.Internal, "failed updating account with domain") } @@ -949,8 +950,8 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. -func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) { - account, err := am.Store.GetAccount(accountID) +func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) { + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -969,7 +970,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( key := user.IntegrationReference.CacheKey(accountID, user.Id) info, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { - log.Infof("Get ExternalCache for key: %s, error: %s", key, err) + log.WithContext(ctx).Infof("Get ExternalCache for key: %s, error: %s", key, err) users[user.Id] = true continue } @@ -980,12 +981,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) } } - queriedUsers, err = am.lookupCache(users, accountID) + queriedUsers, err = am.lookupCache(ctx, users, accountID) if err != nil { return nil, err } - log.Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID) - log.Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID) + log.WithContext(ctx).Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID) + log.WithContext(ctx).Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID) queriedUsers = append(queriedUsers, usersFromIntegration...) } @@ -1052,7 +1053,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { if peer.Status.LoginExpired { @@ -1065,6 +1066,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers [] return err } am.StoreEvent( + ctx, peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), ) @@ -1072,34 +1074,34 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers [] if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service - am.peersUpdateManager.CloseChannels(peerIDs) - am.updateAccountPeers(account) + am.peersUpdateManager.CloseChannels(ctx, peerIDs) + am.updateAccountPeers(ctx, account) } return nil } -func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID string) error { +func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUserID, accountID string) error { if am.userDeleteFromIDPEnabled { - log.Debugf("user %s deleted from IdP", targetUserID) - err := am.idpManager.DeleteUser(targetUserID) + log.WithContext(ctx).Debugf("user %s deleted from IdP", targetUserID) + err := am.idpManager.DeleteUser(ctx, targetUserID) if err != nil { return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err) } } else { - err := am.idpManager.UpdateUserAppMetadata(targetUserID, idp.AppMetadata{}) + err := am.idpManager.UpdateUserAppMetadata(ctx, targetUserID, idp.AppMetadata{}) if err != nil { return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err) } } - err := am.removeUserFromCache(accountID, targetUserID) + err := am.removeUserFromCache(ctx, accountID, targetUserID) if err != nil { - log.Errorf("remove user from account (%q) cache failed with error: %v", accountID, err) + log.WithContext(ctx).Errorf("remove user from account (%q) cache failed with error: %v", accountID, err) } return nil } -func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(accountId, initiatorId, targetId string) (string, string, error) { - userInfos, err := am.GetUsersFromAccount(accountId, initiatorId) +func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context, accountId, initiatorId, targetId string) (string, string, error) { + userInfos, err := am.GetUsersFromAccount(ctx, accountId, initiatorId) if err != nil { return "", "", err } diff --git a/management/server/user_test.go b/management/server/user_test.go index 5edb811c6..99d2792df 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -39,10 +39,10 @@ const ( func TestUser_CreatePAT_ForSameUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -52,7 +52,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - pat, err := am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) + pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -77,13 +77,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockTargetUserId] = &User{ Id: mockTargetUserId, IsServiceUser: false, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -93,19 +93,19 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) + _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) assert.Errorf(t, err, "Creating PAT for different user should thorw error") } func TestUser_CreatePAT_ForServiceUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockTargetUserId] = &User{ Id: mockTargetUserId, IsServiceUser: true, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -115,7 +115,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - pat, err := am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) + pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -125,10 +125,10 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -138,16 +138,16 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) + _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) assert.Errorf(t, err, "Wrong expiration should thorw error") } func TestUser_CreatePAT_WithEmptyName(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -157,14 +157,14 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) + _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) assert.Errorf(t, err, "Wrong expiration should thorw error") } func TestUser_DeletePAT(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ Id: mockUserID, PATs: map[string]*PersonalAccessToken{ @@ -174,7 +174,7 @@ func TestUser_DeletePAT(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -184,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - err = am.DeletePAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) + err = am.DeletePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -196,8 +196,8 @@ func TestUser_DeletePAT(t *testing.T) { func TestUser_GetPAT(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ Id: mockUserID, PATs: map[string]*PersonalAccessToken{ @@ -207,7 +207,7 @@ func TestUser_GetPAT(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -217,7 +217,7 @@ func TestUser_GetPAT(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - pat, err := am.GetPAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) + pat, err := am.GetPAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -228,8 +228,8 @@ func TestUser_GetPAT(t *testing.T) { func TestUser_GetAllPATs(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ Id: mockUserID, PATs: map[string]*PersonalAccessToken{ @@ -243,7 +243,7 @@ func TestUser_GetAllPATs(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -253,7 +253,7 @@ func TestUser_GetAllPATs(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - pats, err := am.GetAllPATs(mockAccountID, mockUserID, mockUserID) + pats, err := am.GetAllPATs(context.Background(), mockAccountID, mockUserID, mockUserID) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -330,10 +330,10 @@ func validateStruct(s interface{}) (err error) { func TestUser_CreateServiceUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -343,7 +343,7 @@ func TestUser_CreateServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - user, err := am.createServiceUser(mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"}) + user, err := am.createServiceUser(context.Background(), mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"}) if err != nil { t.Fatalf("Error when creating service user: %s", err) } @@ -360,7 +360,7 @@ func TestUser_CreateServiceUser(t *testing.T) { assert.True(t, user.IsServiceUser) assert.Equal(t, "active", user.Status) - _, err = am.createServiceUser(mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil) + _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil) if err == nil { t.Fatal("should return error when creating service user with owner role") } @@ -368,10 +368,10 @@ func TestUser_CreateServiceUser(t *testing.T) { func TestUser_CreateUser_ServiceUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -381,7 +381,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - user, err := am.CreateUser(mockAccountID, mockUserID, &UserInfo{ + user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: true, @@ -407,10 +407,10 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { func TestUser_CreateUser_RegularUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -420,7 +420,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreateUser(mockAccountID, mockUserID, &UserInfo{ + _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: false, @@ -432,10 +432,10 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { func TestUser_InviteNewUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -459,7 +459,7 @@ func TestUser_InviteNewUser(t *testing.T) { } idpMock := idp.MockIDP{ - CreateUserFunc: func(email, name, accountID, invitedByEmail string) (*idp.UserData, error) { + CreateUserFunc: func(_ context.Context, email, name, accountID, invitedByEmail string) (*idp.UserData, error) { newData := &idp.UserData{ Email: email, Name: name, @@ -470,7 +470,7 @@ func TestUser_InviteNewUser(t *testing.T) { return newData, nil }, - GetAccountFunc: func(accountId string) ([]*idp.UserData, error) { + GetAccountFunc: func(_ context.Context, accountId string) ([]*idp.UserData, error) { return mockData, nil }, } @@ -478,7 +478,7 @@ func TestUser_InviteNewUser(t *testing.T) { am.idpManager = &idpMock // test if new invite with regular role works - _, err = am.inviteNewUser(mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ Name: mockServiceUserName, Role: mockRole, Email: "test@teste.com", @@ -489,7 +489,7 @@ func TestUser_InviteNewUser(t *testing.T) { assert.NoErrorf(t, err, "Invite user should not throw error") // test if new invite with owner role fails - _, err = am.inviteNewUser(mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ Name: mockServiceUserName, Role: string(UserRoleOwner), Email: "test2@teste.com", @@ -532,10 +532,10 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(mockAccountID, mockUserID, "") + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = tt.serviceUser - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -545,7 +545,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - err = am.DeleteUser(mockAccountID, mockUserID, mockServiceUserID) + err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID) tt.assertErrFunc(t, err, tt.assertErrMessage) if err != nil { @@ -561,10 +561,10 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -574,7 +574,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - err = am.DeleteUser(mockAccountID, mockUserID, mockUserID) + err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockUserID) if err == nil { t.Fatalf("failed to prevent self deletion") } @@ -582,8 +582,8 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { func TestUser_DeleteUser_regularUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" account.Users[targetId] = &User{ @@ -612,7 +612,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { Role: UserRoleOwner, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -655,7 +655,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - err = am.DeleteUser(mockAccountID, mockUserID, testCase.userID) + err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, testCase.userID) testCase.assertErrFunc(t, err, testCase.assertErrMessage) }) } @@ -664,10 +664,10 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -681,7 +681,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { UserId: mockUserID, } - user, err := am.GetUser(claims) + user, err := am.GetUser(context.Background(), claims) if err != nil { t.Fatalf("Error when checking user role: %s", err) } @@ -693,12 +693,12 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users["normal_user1"] = NewRegularUser("normal_user1") account.Users["normal_user2"] = NewRegularUser("normal_user2") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -708,7 +708,7 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - users, err := am.ListUsers(mockAccountID) + users, err := am.ListUsers(context.Background(), mockAccountID) if err != nil { t.Fatalf("Error when checking user role: %s", err) } @@ -775,12 +775,12 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(mockAccountID, mockUserID, "") + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings delete(account.Users, mockUserID) - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -790,7 +790,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - users, err := am.ListUsers(mockAccountID) + users, err := am.ListUsers(context.Background(), mockAccountID) if err != nil { t.Fatalf("Error when checking user role: %s", err) } @@ -806,8 +806,8 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { func TestDefaultAccountManager_ExternalCache(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") externalUser := &User{ Id: "externalUser", Role: UserRoleUser, @@ -819,7 +819,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { } account.Users[externalUser.Id] = externalUser - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -846,7 +846,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}) assert.NoError(t, err) - infos, err := am.GetUsersFromAccount(mockAccountID, mockUserID) + infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) assert.NoError(t, err) assert.Equal(t, 2, len(infos)) var user *UserInfo @@ -870,15 +870,15 @@ func TestUser_IsAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = &User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -888,7 +888,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - users, err := am.GetUsersFromAccount(mockAccountID, mockUserID) + users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) if err != nil { t.Fatalf("Error when getting users from account: %s", err) } @@ -898,16 +898,16 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) - account := newAccountWithId(mockAccountID, mockUserID, "") + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = &User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -917,7 +917,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - users, err := am.GetUsersFromAccount(mockAccountID, mockServiceUserID) + users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockServiceUserID) if err != nil { t.Fatalf("Error when getting users from account: %s", err) } @@ -1069,7 +1069,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // create an account and an admin user - account, err := manager.GetOrCreateAccountByUser(ownerUserID, "netbird.io") + account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io") if err != nil { t.Fatal(err) } @@ -1078,12 +1078,12 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { account.Users[regularUserID] = NewRegularUser(regularUserID) account.Users[adminUserID] = NewAdminUser(adminUserID) account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} - err = manager.Store.SaveAccount(account) + err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) } - updated, err := manager.SaveUser(account.Id, tc.initiatorID, tc.update) + updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update) if tc.expectedErr { require.Errorf(t, err, "expecting SaveUser to throw an error") } else { diff --git a/relay/server/relay.go b/relay/server/relay.go index 928c52322..e3aefb06a 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -106,8 +106,8 @@ func (r *Relay) handShake(conn net.Conn) ([]byte, error) { } if err := r.validator.Validate(authPayload); err != nil { - log.Errorf("failed to authenticate peer with id: %s, %s", peerID, err) - return nil, fmt.Errorf("failed to authenticate peer") + log.Errorf("failed to authenticate connection: %s", err) + return nil, err } msg, _ := messages.MarshalHelloResponse(r.instaceURL) diff --git a/route/hauniqueid.go b/route/hauniqueid.go index 6f74563e2..4d952beba 100644 --- a/route/hauniqueid.go +++ b/route/hauniqueid.go @@ -2,12 +2,9 @@ package route import "strings" -type HAUniqueID string +const haSeparator = "|" -// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address -func GetHAUniqueID(input *Route) HAUniqueID { - return HAUniqueID(string(input.NetID) + "-" + input.Network.String()) -} +type HAUniqueID string func (id HAUniqueID) String() string { return string(id) @@ -15,7 +12,7 @@ func (id HAUniqueID) String() string { // NetID returns the Network ID from the HAUniqueID func (id HAUniqueID) NetID() NetID { - if i := strings.LastIndex(string(id), "-"); i != -1 { + if i := strings.LastIndex(string(id), haSeparator); i != -1 { return NetID(id[:i]) } return NetID(id) diff --git a/route/route.go b/route/route.go index 50c53cbe6..eb6c36bd8 100644 --- a/route/route.go +++ b/route/route.go @@ -1,8 +1,13 @@ package route import ( + "fmt" "net/netip" + "slices" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/status" ) @@ -25,6 +30,8 @@ const ( IPv4NetworkString = "IPv4" // IPv6NetworkString IPv6 network type string IPv6NetworkString = "IPv6" + // DomainNetworkString domain network type string + DomainNetworkString = "Domain" ) const ( @@ -34,6 +41,8 @@ const ( IPv4Network // IPv6Network IPv6 network type IPv6Network + // DomainNetwork domain network type + DomainNetwork ) type ID string @@ -52,6 +61,8 @@ func (p NetworkType) String() string { return IPv4NetworkString case IPv6Network: return IPv6NetworkString + case DomainNetwork: + return DomainNetworkString default: return InvalidNetworkString } @@ -64,6 +75,8 @@ func ToPrefixType(prefix string) NetworkType { return IPv4Network case IPv6NetworkString: return IPv6Network + case DomainNetworkString: + return DomainNetwork default: return InvalidNetwork } @@ -73,8 +86,11 @@ func ToPrefixType(prefix string) NetworkType { type Route struct { ID ID `gorm:"primaryKey"` // AccountID is a reference to Account that this object belongs - AccountID string `gorm:"index"` + AccountID string `gorm:"index"` + // Network and Domains are mutually exclusive Network netip.Prefix `gorm:"serializer:json"` + Domains domain.List `gorm:"serializer:json"` + KeepRoute bool NetID NetID Description string Peer string @@ -88,7 +104,7 @@ type Route struct { // EventMeta returns activity event meta related to the route func (r *Route) EventMeta() map[string]any { - return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "peer_id": r.Peer, "peer_groups": r.PeerGroups} + return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "domains": r.Domains.SafeString(), "peer_id": r.Peer, "peer_groups": r.PeerGroups} } // Copy copies a route object @@ -98,16 +114,16 @@ func (r *Route) Copy() *Route { Description: r.Description, NetID: r.NetID, Network: r.Network, + Domains: slices.Clone(r.Domains), + KeepRoute: r.KeepRoute, NetworkType: r.NetworkType, Peer: r.Peer, - PeerGroups: make([]string, len(r.PeerGroups)), + PeerGroups: slices.Clone(r.PeerGroups), Metric: r.Metric, Masquerade: r.Masquerade, Enabled: r.Enabled, - Groups: make([]string, len(r.Groups)), + Groups: slices.Clone(r.Groups), } - copy(route.Groups, r.Groups) - copy(route.PeerGroups, r.PeerGroups) return route } @@ -123,13 +139,32 @@ func (r *Route) IsEqual(other *Route) bool { other.Description == r.Description && other.NetID == r.NetID && other.Network == r.Network && + slices.Equal(r.Domains, other.Domains) && + other.KeepRoute == r.KeepRoute && other.NetworkType == r.NetworkType && other.Peer == r.Peer && other.Metric == r.Metric && other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && - compareList(r.Groups, other.Groups) && - compareList(r.PeerGroups, other.PeerGroups) + slices.Equal(r.Groups, other.Groups) && + slices.Equal(r.PeerGroups, other.PeerGroups) +} + +// IsDynamic returns if the route is dynamic, i.e. has domains +func (r *Route) IsDynamic() bool { + return r.NetworkType == DomainNetwork +} + +func (r *Route) GetHAUniqueID() HAUniqueID { + if r.IsDynamic() { + domains, err := r.Domains.String() + if err != nil { + log.Errorf("Failed to convert domains to string: %v", err) + domains = r.Domains.PunycodeString() + } + return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, domains)) + } + return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.Network.String())) } // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6 @@ -151,23 +186,3 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { return IPv4Network, masked, nil } - -func compareList(list, other []string) bool { - if len(list) != len(other) { - return false - } - for _, id := range list { - match := false - for _, otherID := range other { - if id == otherID { - match = true - break - } - } - if !match { - return false - } - } - - return true -} diff --git a/signal/README.md b/signal/README.md index 96c80a490..dd2d761ad 100644 --- a/signal/README.md +++ b/signal/README.md @@ -1,9 +1,12 @@ # netbird Signal Server -This is a netbird signal-exchange server and client library to exchange connection information between netbird peers +This is a netbird signal-exchange server and client library to exchange +connection information between netbird peers ## Command Options -The CLI accepts the command **management** with the following options: + +The CLI accepts the the following options: + ```shell start Netbird Signal Server daemon @@ -20,24 +23,38 @@ Global Flags: --log-file string sets Netbird log path. If console is specified the the log will be output to stdout (default "/var/log/netbird/signal.log") --log-level string (default "info") ``` + ## Running the Signal service (Docker) -We have packed the Signal server into docker image. You can pull the image from Docker Hub and execute it with the following commands: +We have packed the Signal server into docker image. You can pull the image from +Docker Hub and execute it with the +following commands: + ````shell docker pull netbirdio/signal:latest docker run -d --name netbird-signal -p 10000:10000 netbirdio/signal:latest ```` -The default log-level is set to INFO, if you need you can change it using by updating the docker cmd as followed: + +The default log-level is set to INFO, if you need you can change it using by +updating the docker cmd as followed: + ````shell docker run -d --name netbird-signal -p 10000:10000 netbirdio/signal:latest --log-level DEBUG ```` + ### Run with TLS (Let's Encrypt). -By specifying the **--letsencrypt-domain** the daemon will handle SSL certificate request and configuration. -In the following example ```10000``` is the signal service **default** port, and ```443``` will be used as port for Let's Encrypt challenge and HTTP API. -> The server where you are running a container has to have a public IP (for Let's Encrypt certificate challenge). +By specifying the **--letsencrypt-domain** the daemon will handle SSL +certificate request and configuration. -Replace with your server's public domain (e.g. mydomain.com or subdomain sub.mydomain.com). +In the following example ```10000``` is the signal service **default** port, +and ```443``` will be used as port for +Let's Encrypt challenge and HTTP API. +> The server where you are running a container has to have a public IP (for +> Let's Encrypt certificate challenge). + +Replace `` with your server's public domain (e.g. mydomain.com or +subdomain sub.mydomain.com). ```bash # create a volume @@ -50,14 +67,57 @@ docker run -d --name netbird-signal \ netbirdio/signal:latest \ --letsencrypt-domain ``` + +## Metrics + +The Signal Server exposes the following metrics in Prometheus format: + +### Application Metrics + +- **active_peers**: A Gauge metric that tracks the number of active peers + connected to the server. +- **peer_connection_duration_seconds**: A Histogram metric that measures the + duration a peer was connected in seconds. +- **registrations_total**: A Counter metric that counts the total number of peer + registrations. +- **deregistrations_total**: A Counter metric that counts the total number of + peer deregistrations. +- **registration_failures_total**: A Counter metric that counts the total number + of failed peer registrations. Possible + labels: + - `error`: The type of error that caused the registration failure ( + e.g., `missing_id`, `missing_meta`, `failed_header`). +- **registration_delay_milliseconds**: A Histogram metric that measures the time + it took to register a peer in + milliseconds. +- **messages_forwarded_total**: A Counter metric that counts the total number of + messages forwarded between peers. +- **message_forward_failures_total**: A Counter metric that counts the total + number of failed message forwards between + peers. Possible labels: + - `type`: The type of failure ( + e.g., `error`, `not_connected`, `not_registered`). +- **message_forward_latency_milliseconds**: A Histogram metric that measures the + latency of message forwarding between + peers in milliseconds. + +### Endpoint + +The metrics are exposed in Prometheus format on the `/metrics` endpoint. By +default, the server listens on port `9090`, +so the full endpoint would be: + +> http://:9090/metrics + ## For development purposes: The project uses gRpc library and defines service in protobuf file located in: - ```proto/signalexchange.proto``` +```proto/signalexchange.proto``` To build the project you have to do the following things. Install golang gRpc tools: + ```bash #!/bin/bash go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26 diff --git a/signal/client/client_test.go b/signal/client/client_test.go index 3ad348b6f..2525493b4 100644 --- a/signal/client/client_test.go +++ b/signal/client/client_test.go @@ -9,6 +9,7 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -198,7 +199,11 @@ func startSignal() (*grpc.Server, net.Listener) { panic(err) } s := grpc.NewServer() - sigProto.RegisterSignalExchangeServer(s, server.NewServer()) + srv, err := server.NewServer(otel.Meter("")) + if err != nil { + panic(err) + } + sigProto.RegisterSignalExchangeServer(s, srv) go func() { if err := s.Serve(lis); err != nil { log.Fatalf("failed to serve: %v", err) diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 10a2da636..4b0dc583e 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "errors" "flag" "fmt" @@ -13,8 +14,11 @@ import ( "strings" "time" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "golang.org/x/crypto/acme/autocert" + "github.com/netbirdio/netbird/signal/metrics" + "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/server" @@ -28,6 +32,10 @@ import ( "google.golang.org/grpc/keepalive" ) +const ( + metricsPort = 9090 +) + var ( signalPort int signalLetsencryptDomain string @@ -95,9 +103,26 @@ var ( opts = append(opts, grpc.Creds(transportCredentials)) } - opts = append(opts, signalKaep, signalKasp) + metricsServer := metrics.NewServer(metricsPort, "") + if err != nil { + return fmt.Errorf("setup metrics: %v", err) + } + + opts = append(opts, signalKaep, signalKasp, grpc.StatsHandler(otelgrpc.NewServerHandler())) grpcServer := grpc.NewServer(opts...) - proto.RegisterSignalExchangeServer(grpcServer, server.NewServer()) + + go func() { + log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint) + if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("Failed to start metrics server: %v", err) + } + }() + + srv, err := server.NewServer(metricsServer.Meter) + if err != nil { + return fmt.Errorf("creating signal server: %v", err) + } + proto.RegisterSignalExchangeServer(grpcServer, srv) var compatListener net.Listener if signalPort != 10000 { @@ -150,6 +175,14 @@ var ( _ = compatListener.Close() log.Infof("stopped gRPC backward compatibility server") } + + ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second) + defer cancel() + if err := metricsServer.Shutdown(ctx); err != nil { + log.Errorf("Failed to stop metrics server: %v", err) + } + log.Infof("stopped metrics server") + log.Infof("stopped Signal Service") return nil diff --git a/signal/metrics/app.go b/signal/metrics/app.go new file mode 100644 index 000000000..fb882a5d4 --- /dev/null +++ b/signal/metrics/app.go @@ -0,0 +1,124 @@ +package metrics + +import ( + "go.opentelemetry.io/otel/metric" +) + +// AppMetrics holds all the application metrics +type AppMetrics struct { + metric.Meter + + ActivePeers metric.Int64UpDownCounter + PeerConnectionDuration metric.Int64Histogram + + Registrations metric.Int64Counter + Deregistrations metric.Int64Counter + RegistrationFailures metric.Int64Counter + RegistrationDelay metric.Float64Histogram + + MessagesForwarded metric.Int64Counter + MessageForwardFailures metric.Int64Counter + MessageForwardLatency metric.Float64Histogram +} + +func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) { + activePeers, err := meter.Int64UpDownCounter("active_peers") + if err != nil { + return nil, err + } + + peerConnectionDuration, err := meter.Int64Histogram("peer_connection_duration_seconds", + metric.WithExplicitBucketBoundaries(getPeerConnectionDurationBucketBoundaries()...)) + if err != nil { + return nil, err + } + + registrations, err := meter.Int64Counter("registrations_total") + if err != nil { + return nil, err + } + + deregistrations, err := meter.Int64Counter("deregistrations_total") + if err != nil { + return nil, err + } + + registrationFailures, err := meter.Int64Counter("registration_failures_total") + if err != nil { + return nil, err + } + + registrationDelay, err := meter.Float64Histogram("registration_delay_milliseconds", + metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + + messagesForwarded, err := meter.Int64Counter("messages_forwarded_total") + if err != nil { + return nil, err + } + + messageForwardFailures, err := meter.Int64Counter("message_forward_failures_total") + if err != nil { + return nil, err + } + + messageForwardLatency, err := meter.Float64Histogram("message_forward_latency_milliseconds", + metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + + return &AppMetrics{ + Meter: meter, + + ActivePeers: activePeers, + PeerConnectionDuration: peerConnectionDuration, + + Registrations: registrations, + Deregistrations: deregistrations, + RegistrationFailures: registrationFailures, + RegistrationDelay: registrationDelay, + + MessagesForwarded: messagesForwarded, + MessageForwardFailures: messageForwardFailures, + MessageForwardLatency: messageForwardLatency, + }, nil +} + +func getStandardBucketBoundaries() []float64 { + return []float64{ + 0.1, + 0.5, + 1, + 5, + 10, + 50, + 100, + 500, + 1000, + 5000, + 10000, + } +} +func getPeerConnectionDurationBucketBoundaries() []float64 { + return []float64{ + 1, + 60, + // 10m + 600, + // 1h + 3600, + // 2h, + 7200, + // 6h, + 21600, + // 12h, + 43200, + // 24h, + 86400, + // 48h, + 172800, + } +} diff --git a/signal/metrics/metrics.go b/signal/metrics/metrics.go new file mode 100644 index 000000000..30db1600a --- /dev/null +++ b/signal/metrics/metrics.go @@ -0,0 +1,74 @@ +package metrics + +import ( + "context" + "fmt" + "net/http" + "reflect" + + prometheus2 "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/exporters/prometheus" + api "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/sdk/metric" +) + +const defaultEndpoint = "/metrics" + +// Metrics holds the metrics information and exposes it +type Metrics struct { + Meter api.Meter + provider *metric.MeterProvider + Endpoint string + + *http.Server +} + +// NewServer initializes and returns a new Metrics instance +func NewServer(port int, endpoint string) *Metrics { + exporter, err := prometheus.New() + if err != nil { + return nil + } + + provider := metric.NewMeterProvider(metric.WithReader(exporter)) + otel.SetMeterProvider(provider) + + pkg := reflect.TypeOf(defaultEndpoint).PkgPath() + meter := provider.Meter(pkg) + + if endpoint == "" { + endpoint = defaultEndpoint + } + + router := http.NewServeMux() + router.Handle(endpoint, promhttp.HandlerFor( + prometheus2.DefaultGatherer, + promhttp.HandlerOpts{EnableOpenMetrics: true})) + + server := &http.Server{ + Addr: fmt.Sprintf(":%d", port), + Handler: router, + } + + return &Metrics{ + Meter: meter, + provider: provider, + Endpoint: endpoint, + Server: server, + } +} + +// Shutdown stops the metrics server +func (m *Metrics) Shutdown(ctx context.Context) error { + if err := m.Server.Shutdown(ctx); err != nil { + return fmt.Errorf("http server: %w", err) + } + + if err := m.provider.Shutdown(ctx); err != nil { + return fmt.Errorf("meter provider: %w", err) + } + + return nil +} diff --git a/signal/peer/peer.go b/signal/peer/peer.go index 612e250a5..3149526b2 100644 --- a/signal/peer/peer.go +++ b/signal/peer/peer.go @@ -1,10 +1,14 @@ package peer import ( - "github.com/netbirdio/netbird/signal/proto" - log "github.com/sirupsen/logrus" + "context" "sync" "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/signal/metrics" + "github.com/netbirdio/netbird/signal/proto" ) // Peer representation of a connected Peer @@ -33,12 +37,14 @@ type Registry struct { Peers sync.Map // regMutex ensures that registration and de-registrations are safe regMutex sync.Mutex + metrics *metrics.AppMetrics } // NewRegistry creates a new connected Peer registry -func NewRegistry() *Registry { +func NewRegistry(metrics *metrics.AppMetrics) *Registry { return &Registry{ regMutex: sync.Mutex{}, + metrics: metrics, } } @@ -60,6 +66,8 @@ func (registry *Registry) IsPeerRegistered(peerId string) bool { // Register registers peer in the registry func (registry *Registry) Register(peer *Peer) { + start := time.Now() + registry.regMutex.Lock() defer registry.regMutex.Unlock() @@ -72,6 +80,11 @@ func (registry *Registry) Register(peer *Peer) { registry.Peers.Store(peer.Id, peer) } log.Debugf("peer registered [%s]", peer.Id) + + // record time as milliseconds + registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6) + + registry.metrics.Registrations.Add(context.Background(), 1) } // Deregister Peer from the Registry (usually once it disconnects) @@ -90,4 +103,6 @@ func (registry *Registry) Deregister(peer *Peer) { } } log.Debugf("peer deregistered [%s]", peer.Id) + + registry.metrics.Deregistrations.Add(context.Background(), 1) } diff --git a/signal/peer/peer_test.go b/signal/peer/peer_test.go index bf3dc706a..fb85fedda 100644 --- a/signal/peer/peer_test.go +++ b/signal/peer/peer_test.go @@ -1,13 +1,21 @@ package peer import ( - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + + "github.com/netbirdio/netbird/signal/metrics" ) func TestRegistry_ShouldNotDeregisterWhenHasNewerStreamRegistered(t *testing.T) { - r := NewRegistry() + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + + r := NewRegistry(metrics) peerID := "peer" @@ -30,7 +38,10 @@ func TestRegistry_ShouldNotDeregisterWhenHasNewerStreamRegistered(t *testing.T) } func TestRegistry_GetNonExistentPeer(t *testing.T) { - r := NewRegistry() + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + + r := NewRegistry(metrics) peer, ok := r.Get("non_existent_peer") @@ -44,7 +55,10 @@ func TestRegistry_GetNonExistentPeer(t *testing.T) { } func TestRegistry_Register(t *testing.T) { - r := NewRegistry() + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + + r := NewRegistry(metrics) peer1 := NewPeer("test_peer_1", nil) peer2 := NewPeer("test_peer_2", nil) r.Register(peer1) @@ -60,7 +74,10 @@ func TestRegistry_Register(t *testing.T) { } func TestRegistry_Deregister(t *testing.T) { - r := NewRegistry() + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + + r := NewRegistry(metrics) peer1 := NewPeer("test_peer_1", nil) peer2 := NewPeer("test_peer_2", nil) r.Register(peer1) diff --git a/signal/server/signal.go b/signal/server/signal.go index 84045e800..fc9c19efd 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -3,72 +3,114 @@ package server import ( "context" "fmt" - "github.com/netbirdio/netbird/signal/peer" - "github.com/netbirdio/netbird/signal/proto" + "io" + "time" + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "io" + + "github.com/netbirdio/netbird/signal/metrics" + "github.com/netbirdio/netbird/signal/peer" + "github.com/netbirdio/netbird/signal/proto" +) + +const ( + labelType = "type" + labelTypeError = "error" + labelTypeNotConnected = "not_connected" + labelTypeNotRegistered = "not_registered" + + labelError = "error" + labelErrorMissingId = "missing_id" + labelErrorMissingMeta = "missing_meta" + labelErrorFailedHeader = "failed_header" ) // Server an instance of a Signal server type Server struct { registry *peer.Registry proto.UnimplementedSignalExchangeServer + + metrics *metrics.AppMetrics } // NewServer creates a new Signal server -func NewServer() *Server { - return &Server{ - registry: peer.NewRegistry(), +func NewServer(meter metric.Meter) (*Server, error) { + appMetrics, err := metrics.NewAppMetrics(meter) + if err != nil { + return nil, fmt.Errorf("creating app metrics: %v", err) } + + s := &Server{ + registry: peer.NewRegistry(appMetrics), + metrics: appMetrics, + } + + return s, nil } // Send forwards a message to the signal peer func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { - if !s.registry.IsPeerRegistered(msg.Key) { + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotRegistered))) + return nil, fmt.Errorf("peer %s is not registered", msg.Key) } if dstPeer, found := s.registry.Get(msg.RemoteKey); found { //forward the message to the target peer - err := dstPeer.Stream.Send(msg) - if err != nil { + if err := dstPeer.Stream.Send(msg); err != nil { log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) //todo respond to the sender? + + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + } else { + s.metrics.MessagesForwarded.Add(context.Background(), 1) } } else { log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) //todo respond to the sender? + + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) } return &proto.EncryptedMessage{}, nil } // ConnectStream connects to the exchange stream func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error { - p, err := s.connectPeer(stream) if err != nil { return err } + startRegister := time.Now() + + s.metrics.ActivePeers.Add(stream.Context(), 1) + defer func() { log.Infof("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID) s.registry.Deregister(p) + + s.metrics.PeerConnectionDuration.Record(stream.Context(), int64(time.Since(startRegister).Seconds())) + s.metrics.ActivePeers.Add(context.Background(), -1) }() //needed to confirm that the peer has been registered so that the client can proceed header := metadata.Pairs(proto.HeaderRegistered, "1") err = stream.SendHeader(header) if err != nil { + s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader))) return err } log.Infof("peer connected [%s] [streamID %d] ", p.Id, p.StreamID) for { + //read incoming messages msg, err := stream.Recv() if err == io.EOF { @@ -76,18 +118,28 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) } else if err != nil { return err } + start := time.Now() + log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey) + // lookup the target peer where the message is going to if dstPeer, found := s.registry.Get(msg.RemoteKey); found { //forward the message to the target peer - err := dstPeer.Stream.Send(msg) - if err != nil { + if err := dstPeer.Stream.Send(msg); err != nil { log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err) //todo respond to the sender? + + // in milliseconds + s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6) + s.metrics.MessagesForwarded.Add(stream.Context(), 1) + } else { + s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) } } else { log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey) //todo respond to the sender? + + s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) } } <-stream.Context().Done() @@ -101,12 +153,16 @@ func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*p if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta { if id, found := meta[proto.HeaderId]; found { p := peer.NewPeer(id[0], stream) + s.registry.Register(p) + return p, nil } else { + s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: "+proto.HeaderId) } } else { + s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta))) return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta") } } diff --git a/util/file.go b/util/file.go index 0cbfa37ab..2a6182556 100644 --- a/util/file.go +++ b/util/file.go @@ -1,6 +1,7 @@ package util import ( + "context" "encoding/json" "io" "os" @@ -57,7 +58,7 @@ func WriteJson(file string, obj interface{}) error { } // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file -func DirectWriteJson(file string, obj interface{}) error { +func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { _, _, err := prepareConfigFileDir(file) if err != nil { diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 63c56de17..3fba0c84e 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -27,7 +27,6 @@ func WithCustomDialer() grpc.DialOption { } } - conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { log.Errorf("Failed to dial: %s", err) diff --git a/util/log.go b/util/log.go index fda15a541..90ccea48f 100644 --- a/util/log.go +++ b/util/log.go @@ -2,6 +2,7 @@ package util import ( "io" + "os" "path/filepath" log "github.com/sirupsen/logrus" @@ -30,7 +31,11 @@ func InitLog(logLevel string, logPath string) error { log.SetOutput(io.Writer(lumberjackLogger)) } - formatter.SetTextFormatter(log.StandardLogger()) + if os.Getenv("NB_LOG_FORMAT") == "json" { + formatter.SetJSONFormatter(log.StandardLogger()) + } else { + formatter.SetTextFormatter(log.StandardLogger()) + } log.SetLevel(level) return nil } diff --git a/util/membership_unix.go b/util/membership_unix.go index 82237461c..a9e55af84 100644 --- a/util/membership_unix.go +++ b/util/membership_unix.go @@ -1,4 +1,4 @@ -//go:build linux || darwin +//go:build linux || darwin || freebsd package util diff --git a/util/net/net.go b/util/net/net.go index 3856911b1..8d1fcebd0 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,8 +1,11 @@ package net import ( + "net" "os" + "github.com/netbirdio/netbird/iface/netstack" + "github.com/google/uuid" ) @@ -17,11 +20,17 @@ const ( // It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. type ConnectionID string +type AddHookFunc func(connID ConnectionID, IP net.IP) error +type RemoveHookFunc func(connID ConnectionID) error + // GenerateConnID generates a unique identifier for each connection. func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } func CustomRoutingDisabled() bool { + if netstack.IsEnabled() { + return true + } return os.Getenv(envDisableCustomRouting) == "true" } diff --git a/version/url_freebsd.go b/version/url_freebsd.go new file mode 100644 index 000000000..c8193e30c --- /dev/null +++ b/version/url_freebsd.go @@ -0,0 +1,6 @@ +package version + +// DownloadUrl return with the proper download link +func DownloadUrl() string { + return downloadURL +}