From 58ff7ab797fcde081b3a0a802487f13a4ab4945a Mon Sep 17 00:00:00 2001 From: adasauce <60991921+adasauce@users.noreply.github.com> Date: Fri, 27 Sep 2024 16:21:34 -0300 Subject: [PATCH 01/37] [management] improve zitadel idp error response detail by decoding errors (#2634) * [management] improve zitadel idp error response detail by decoding errors * [management] extend readZitadelError to be used for requestJWTToken more generically parse the error returned by zitadel. * fix lint --------- Co-authored-by: bcmmbaga --- management/server/idp/zitadel.go | 49 +++++++++++++++++++++++++-- management/server/idp/zitadel_test.go | 10 +++--- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 729b49733..9d7626844 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -2,10 +2,12 @@ package idp import ( "context" + "errors" "fmt" "io" "net/http" "net/url" + "slices" "strings" "sync" "time" @@ -97,6 +99,42 @@ type zitadelUserResponse struct { PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"` } +// readZitadelError parses errors returned by the zitadel APIs from a response. +func readZitadelError(body io.ReadCloser) error { + bodyBytes, err := io.ReadAll(body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + + helper := JsonParser{} + var target map[string]interface{} + err = helper.Unmarshal(bodyBytes, &target) + if err != nil { + return fmt.Errorf("error unparsable body: %s", string(bodyBytes)) + } + + // ensure keys are ordered for consistent logging behaviour. 
+ errorKeys := make([]string, 0, len(target)) + for k := range target { + errorKeys = append(errorKeys, k) + } + slices.Sort(errorKeys) + + var errsOut []string + for _, k := range errorKeys { + if _, isEmbedded := target[k].(map[string]interface{}); isEmbedded { + continue + } + errsOut = append(errsOut, fmt.Sprintf("%s: %v", k, target[k])) + } + + if len(errsOut) == 0 { + return errors.New("unknown error") + } + + return errors.New(strings.Join(errsOut, " ")) +} + // NewZitadelManager creates a new instance of the ZitadelManager. func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() @@ -176,7 +214,8 @@ func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Respon } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode) + zErr := readZitadelError(resp.Body) + return nil, fmt.Errorf("unable to get zitadel token, statusCode %d, zitadel: %w", resp.StatusCode, zErr) } return resp, nil @@ -489,7 +528,9 @@ func (zm *ZitadelManager) post(ctx context.Context, resource string, body string zm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode) + zErr := readZitadelError(resp.Body) + + return nil, fmt.Errorf("unable to post %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr) } return io.ReadAll(resp.Body) @@ -561,7 +602,9 @@ func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values zm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode) + zErr := readZitadelError(resp.Body) + + return nil, fmt.Errorf("unable to get %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr) } return io.ReadAll(resp.Body) diff --git a/management/server/idp/zitadel_test.go 
b/management/server/idp/zitadel_test.go index 6bc612e78..722f94fe0 100644 --- a/management/server/idp/zitadel_test.go +++ b/management/server/idp/zitadel_test.go @@ -66,7 +66,6 @@ func TestNewZitadelManager(t *testing.T) { } func TestZitadelRequestJWTToken(t *testing.T) { - type requestJWTTokenTest struct { name string inputCode int @@ -88,15 +87,14 @@ func TestZitadelRequestJWTToken(t *testing.T) { requestJWTTokenTestCase2 := requestJWTTokenTest{ name: "Request Bad Status Code", inputCode: 400, - inputRespBody: "{}", + inputRespBody: "{\"error\": \"invalid_scope\", \"error_description\":\"openid missing\"}", helper: JsonParser{}, - expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"), + expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: error: invalid_scope error_description: openid missing"), expectedToken: "", } for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} { t.Run(testCase.name, func(t *testing.T) { - jwtReqClient := mockHTTPClient{ resBody: testCase.inputRespBody, code: testCase.inputCode, @@ -156,7 +154,7 @@ func TestZitadelParseRequestJWTResponse(t *testing.T) { } parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{ name: "Parse Bad json JWT Body", - inputRespBody: "", + inputRespBody: "{}", helper: JsonParser{}, expectedToken: "", expectedExpiresIn: 0, @@ -254,7 +252,7 @@ func TestZitadelAuthenticate(t *testing.T) { inputCode: 400, inputResBody: "{}", helper: JsonParser{}, - expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"), + expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: unknown error"), expectedCode: 200, expectedToken: "", } From 52ae693c9e5eff72082d4330eab6723251559546 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sun, 29 Sep 2024 00:22:47 +0200 Subject: [PATCH 02/37] [signal] add 
context to signal-dispatcher (#2662) --- client/cmd/testutil_test.go | 2 +- client/internal/engine_test.go | 2 +- client/server/server_test.go | 2 +- go.mod | 2 +- go.sum | 4 ++-- signal/client/client_test.go | 2 +- signal/cmd/run.go | 2 +- signal/server/signal.go | 4 ++-- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 780cc8b04..f0dc8bf21 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -57,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - srv, err := sig.NewServer(otel.Meter("")) + srv, err := sig.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) sigProto.RegisterSignalExchangeServer(s, srv) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f30566380..95aadf141 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1056,7 +1056,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - srv, err := signalServer.NewServer(otel.Meter("")) + srv, err := signalServer.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) proto.RegisterSignalExchangeServer(s, srv) diff --git a/client/server/server_test.go b/client/server/server_test.go index 795060fab..9b18df4d3 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -160,7 +160,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - srv, err := signalServer.NewServer(otel.Meter("")) + srv, err := signalServer.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) proto.RegisterSignalExchangeServer(s, srv) diff --git a/go.mod b/go.mod index 12709e50d..cf3b610bd 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 2355f6f0c..089629cdf 100644 --- a/go.sum +++ b/go.sum @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513- github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086 h1:NZm4JvvjKuEh3p7daHUy3rWKhKsnUzzYpGv1qT4dYLc= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= diff --git a/signal/client/client_test.go b/signal/client/client_test.go index 2525493b4..f7d4ebc50 100644 --- a/signal/client/client_test.go +++ 
b/signal/client/client_test.go @@ -199,7 +199,7 @@ func startSignal() (*grpc.Server, net.Listener) { panic(err) } s := grpc.NewServer() - srv, err := server.NewServer(otel.Meter("")) + srv, err := server.NewServer(context.Background(), otel.Meter("")) if err != nil { panic(err) } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 0bdc62ead..1bb2f1d0c 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -102,7 +102,7 @@ var ( } }() - srv, err := server.NewServer(metricsServer.Meter) + srv, err := server.NewServer(cmd.Context(), metricsServer.Meter) if err != nil { return fmt.Errorf("creating signal server: %v", err) } diff --git a/signal/server/signal.go b/signal/server/signal.go index b268aa3fc..c020c5604 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -47,13 +47,13 @@ type Server struct { } // NewServer creates a new Signal server -func NewServer(meter metric.Meter) (*Server, error) { +func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { appMetrics, err := metrics.NewAppMetrics(meter) if err != nil { return nil, fmt.Errorf("creating app metrics: %v", err) } - dispatcher, err := dispatcher.NewDispatcher() + dispatcher, err := dispatcher.NewDispatcher(ctx) if err != nil { return nil, fmt.Errorf("creating dispatcher: %v", err) } From cfbcf507fb0ae039c270af48822679a754b8c530 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sun, 29 Sep 2024 20:23:34 +0200 Subject: [PATCH 03/37] propagate meter (#2668) --- go.mod | 2 +- go.sum | 4 ++-- signal/server/signal.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index cf3b610bd..edee0ede4 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e - github.com/netbirdio/signal-dispatcher/dispatcher 
v0.0.0-20240928205912-5569c4c5e086 + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 089629cdf..2160fa1f8 100644 --- a/go.sum +++ b/go.sum @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513- github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086 h1:NZm4JvvjKuEh3p7daHUy3rWKhKsnUzzYpGv1qT4dYLc= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= diff --git a/signal/server/signal.go b/signal/server/signal.go index c020c5604..386ce7238 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -53,7 +53,7 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { return nil, fmt.Errorf("creating app 
metrics: %v", err) } - dispatcher, err := dispatcher.NewDispatcher(ctx) + dispatcher, err := dispatcher.NewDispatcher(ctx, meter) if err != nil { return nil, fmt.Errorf("creating dispatcher: %v", err) } From 3dca6099d4f1a32c2e2ddbabe88a49d786fb3c41 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 30 Sep 2024 10:34:57 +0200 Subject: [PATCH 04/37] Fix ebpf close function (#2672) --- client/internal/wgproxy/ebpf/proxy.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 4bd4bfff6..27ede3ef1 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -81,8 +81,7 @@ func (p *WGEBPFProxy) Listen() error { conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { - cErr := p.Free() - if cErr != nil { + if cErr := p.Free(); cErr != nil { log.Errorf("Failed to close the wgproxy: %s", cErr) } return err @@ -122,8 +121,10 @@ func (p *WGEBPFProxy) Free() error { p.ctxCancel() var result *multierror.Error - if err := p.conn.Close(); err != nil { - result = multierror.Append(result, err) + if p.conn != nil { // p.conn will be nil if we have failed to listen + if err := p.conn.Close(); err != nil { + result = multierror.Append(result, err) + } } if err := p.ebpfManager.FreeWGProxy(); err != nil { From 2fd60b2cb46a77f16b5e1e1f72a1a09f03f0ecbe Mon Sep 17 00:00:00 2001 From: Gianluca Boiano <491117+M0Rf30@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:43:34 +0200 Subject: [PATCH 05/37] Specify goreleaser version and update to 2 (#2673) --- .github/workflows/release.yml | 72 +++++++++++++---------------------- .goreleaser.yaml | 32 ++++++++-------- .goreleaser_ui.yaml | 9 +++-- .goreleaser_ui_darwin.yaml | 6 ++- CONTRIBUTING.md | 2 +- 5 files changed, 52 insertions(+), 69 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5f423f1c9..162e488c3 100644 --- a/.github/workflows/release.yml 
+++ b/.github/workflows/release.yml @@ -3,15 +3,14 @@ name: Release on: push: tags: - - 'v*' + - "v*" branches: - main pull_request: - env: SIGN_PIPE_VER: "v0.0.14" - GORELEASER_VER: "v1.14.1" + GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" @@ -34,19 +33,16 @@ jobs: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - - - name: Checkout + - name: Checkout uses: actions/checkout@v4 with: fetch-depth: 0 # It is required for GoReleaser to work properly - - - name: Set up Go + - name: Set up Go uses: actions/setup-go@v5 with: go-version: "1.23" cache: false - - - name: Cache Go modules + - name: Cache Go modules uses: actions/cache@v4 with: path: | @@ -55,20 +51,15 @@ jobs: key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go-releaser- - - - name: Install modules + - name: Install modules run: go mod tidy - - - name: check git status + - name: check git status run: git --no-pager diff --exit-code - - - name: Set up QEMU + - name: Set up QEMU uses: docker/setup-qemu-action@v2 - - - name: Set up Docker Buildx + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - - - name: Login to Docker hub + - name: Login to Docker hub if: github.event_name != 'pull_request' uses: docker/login-action@v1 with: @@ -85,35 +76,31 @@ jobs: uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --rm-dist ${{ env.flags }} + args: release --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} - - - name: upload non tags for debug purposes + - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 with: name: release path: dist/ retention-days: 3 - - - name: upload linux packages + - name: 
upload linux packages uses: actions/upload-artifact@v4 with: name: linux-packages path: dist/netbird_linux** retention-days: 3 - - - name: upload windows packages + - name: upload windows packages uses: actions/upload-artifact@v4 with: name: windows-packages path: dist/netbird_windows** retention-days: 3 - - - name: upload macos packages + - name: upload macos packages uses: actions/upload-artifact@v4 with: name: macos-packages @@ -145,7 +132,7 @@ jobs: - name: Cache Go modules uses: actions/cache@v4 with: - path: | + path: | ~/go/pkg/mod ~/.cache/go-build key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }} @@ -169,7 +156,7 @@ jobs: uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }} + args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} @@ -187,19 +174,16 @@ jobs: steps: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - - - name: Checkout + - name: Checkout uses: actions/checkout@v4 with: fetch-depth: 0 # It is required for GoReleaser to work properly - - - name: Set up Go + - name: Set up Go uses: actions/setup-go@v5 with: go-version: "1.23" cache: false - - - name: Cache Go modules + - name: Cache Go modules uses: actions/cache@v4 with: path: | @@ -208,23 +192,19 @@ jobs: key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-ui-go-releaser-darwin- - - - name: Install modules + - name: Install modules run: go mod tidy - - - name: check git status + - name: check git status run: git --no-pager diff --exit-code - - - name: Run GoReleaser + - name: Run GoReleaser id: goreleaser uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ 
env.flags }} + args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: upload non tags for debug purposes + - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 with: name: release-ui-darwin @@ -233,7 +213,7 @@ jobs: trigger_signer: runs-on: ubuntu-latest - needs: [release,release_ui,release_ui_darwin] + needs: [release, release_ui, release_ui_darwin] if: startsWith(github.ref, 'refs/tags/') steps: - name: Trigger binaries sign pipelines diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 068864d6e..cf2ce4f4f 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird builds: - id: netbird @@ -22,7 +24,7 @@ builds: goarch: 386 ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc @@ -42,19 +44,19 @@ builds: - softfloat ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc - id: netbird-mgmt dir: management env: - - CGO_ENABLED=1 - - >- - {{- if eq .Runtime.Goos "linux" }} - {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} - {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} - {{- end }} + - CGO_ENABLED=1 + - >- + {{- if eq .Runtime.Goos "linux" }} + {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} + {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} + {{- end }} binary: netbird-mgmt goos: - linux @@ -64,7 +66,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X 
main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-signal dir: signal @@ -78,7 +80,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-relay dir: relay @@ -92,7 +94,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" archives: - builds: @@ -100,7 +102,6 @@ archives: - netbird-static nfpms: - - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ @@ -416,10 +417,9 @@ docker_manifests: - netbirdio/management:{{ .Version }}-debug-amd64 brews: - - - ids: + - ids: - default - tap: + repository: owner: netbirdio name: homebrew-tap token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}" @@ -436,7 +436,7 @@ brews: uploads: - name: debian ids: - - netbird-deb + - netbird-deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index fd92b5328..06577f4e3 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird-ui builds: - id: netbird-ui @@ -11,7 +13,7 @@ builds: - amd64 ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-ui-windows dir: 
client/ui @@ -26,7 +28,7 @@ builds: ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -H windowsgui - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" archives: - id: linux-arch @@ -39,7 +41,6 @@ archives: - netbird-ui-windows nfpms: - - maintainer: Netbird description: Netbird client UI. homepage: https://netbird.io/ @@ -77,7 +78,7 @@ nfpms: uploads: - name: debian ids: - - netbird-ui-deb + - netbird-ui-deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml index 2c3afa91b..bccb7f471 100644 --- a/.goreleaser_ui_darwin.yaml +++ b/.goreleaser_ui_darwin.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird-ui builds: - id: netbird-ui-darwin @@ -17,7 +19,7 @@ builds: - softfloat ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc @@ -28,4 +30,4 @@ archives: checksum: name_template: "{{ .ProjectName }}_darwin_checksums.txt" changelog: - skip: true \ No newline at end of file + disable: true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 492aa5c2e..c82cfc763 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR: **Goreleaser** ```shell -goreleaser --snapshot --rm-dist +goreleaser build --snapshot --clean ``` **golangci-lint** ```shell From e27f85b317a97721921933659a80c8be35c785e1 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 30 Sep 2024 20:07:21 +0200 Subject: 
[PATCH 06/37] Update docker creds (#2677) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 162e488c3..7af6d3e4d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -63,7 +63,7 @@ jobs: if: github.event_name != 'pull_request' uses: docker/login-action@v1 with: - username: netbirdio + username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_TOKEN }} - name: Install OS build dependencies run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu From 16179db599ef6fb42e709597bc260101dfa7cd74 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 30 Sep 2024 22:18:10 +0200 Subject: [PATCH 07/37] [management] Propagate metrics (#2667) --- go.mod | 2 +- go.sum | 4 ++-- management/server/http/handler.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index edee0ede4..c29ba0763 100644 --- a/go.mod +++ b/go.mod @@ -59,7 +59,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e + github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 2160fa1f8..1f6cbb785 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e 
h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk= diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ef94f22b9..3f8a8554d 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -82,7 +82,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa AuthCfg: authCfg, } - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } From 24c0aaa745bc2ac46bdcf1f855834306a886db95 Mon Sep 17 00:00:00 2001 From: Simen <97337442+simen64@users.noreply.github.com> Date: Tue, 1 Oct 2024 13:32:58 +0200 Subject: [PATCH 08/37] Install sh 
alpine fixes (#2678) * Made changes to the peer install script that makes it work on alpine linux without changes * fix small oversight with doas fix * use try catch approach when curling binaries --- release_files/install.sh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/release_files/install.sh b/release_files/install.sh index d6aabebd8..5dd0f67bb 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -21,6 +21,8 @@ SUDO="" if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then SUDO="sudo" +elif command -v doas > /dev/null && [ "$(id -u)" -ne 0 ]; then + SUDO="doas" fi if [ -z ${NETBIRD_RELEASE+x} ]; then @@ -68,7 +70,7 @@ download_release_binary() { if [ -n "$GITHUB_TOKEN" ]; then cd /tmp && curl -H "Authorization: token ${GITHUB_TOKEN}" -LO "$DOWNLOAD_URL" else - cd /tmp && curl -LO "$DOWNLOAD_URL" + cd /tmp && curl -LO "$DOWNLOAD_URL" || curl -LO --dns-servers 8.8.8.8 "$DOWNLOAD_URL" fi @@ -316,7 +318,7 @@ install_netbird() { } version_greater_equal() { - printf '%s\n%s\n' "$2" "$1" | sort -V -C + printf '%s\n%s\n' "$2" "$1" | sort -V -c } is_bin_package_manager() { From ee0ea86a0a9394b2632ed2be3149d45c04baca67 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 1 Oct 2024 16:22:18 +0200 Subject: [PATCH 09/37] [relay-client] Fix Relay disconnection handling (#2680) * Fix Relay disconnection handling If has an active P2P connection meanwhile the Relay connection broken with the server then we removed the WireGuard peer configuration. 
* Change logs --- client/internal/peer/conn.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ea6d892b9..baff1372a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -586,13 +586,17 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { return } - if conn.wgProxyRelay != nil { - log.Debugf("relayed connection is closed, clean up WireGuard config") + log.Debugf("relay connection is disconnected") + + if conn.currentConnPriority == connPriorityRelay { + log.Debugf("clean up WireGuard config") err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) if err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } + } + if conn.wgProxyRelay != nil { conn.endpointRelay = nil _ = conn.wgProxyRelay.CloseConn() conn.wgProxyRelay = nil From 5932298ce03ccda417cbf954020665fdc096baaa Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 2 Oct 2024 11:48:09 +0200 Subject: [PATCH 10/37] Add log setting to Caddy container (#2684) This avoids full disk on busy systems --- infrastructure_files/getting-started-with-zitadel.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index c0275536b..2c5c35d53 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -793,6 +793,11 @@ services: volumes: - netbird_caddy_data:/data - ./Caddyfile:/etc/caddy/Caddyfile + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" # UI dashboard dashboard: image: netbirdio/dashboard:latest From a3a479429eb13dc53b9d9dd7bfb1b0710c5055c0 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 2 Oct 2024 11:48:42 +0200 Subject: [PATCH 11/37] Use the pkgs to get the latest version (#2682) * Use the pkgs to get the latest version * disable fail fast 
--- .github/workflows/install-script-test.yml | 1 + release_files/install.sh | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/install-script-test.yml b/.github/workflows/install-script-test.yml index 04c222e87..22d002a48 100644 --- a/.github/workflows/install-script-test.yml +++ b/.github/workflows/install-script-test.yml @@ -13,6 +13,7 @@ concurrency: jobs: test-install-script: strategy: + fail-fast: false max-parallel: 2 matrix: os: [ubuntu-latest, macos-latest] diff --git a/release_files/install.sh b/release_files/install.sh index 5dd0f67bb..b7a6c08f9 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -33,14 +33,16 @@ get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then local TAG="latest" + local URL="https://pkgs.netbird.io/releases/latest" else local TAG="tags/${RELEASE}" + local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ + curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' else - curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ + curl -s "${URL}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' fi } From ff7863785f81c64ce0570b28950f806b75800c6a Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 2 Oct 2024 14:41:00 +0300 Subject: [PATCH 12/37] [management, client] Add access control support to network routes (#2100) --- .github/workflows/golangci-lint.yml | 2 +- client/firewall/iface.go | 4 +- client/firewall/iptables/acl_linux.go | 174 +-- client/firewall/iptables/manager_linux.go | 64 +- .../firewall/iptables/manager_linux_test.go | 54 +- client/firewall/iptables/router_linux.go | 616 ++++++---- client/firewall/iptables/router_linux_test.go | 270 ++-- client/firewall/manager/firewall.go | 125 +- 
client/firewall/manager/firewall_test.go | 192 +++ client/firewall/manager/routerpair.go | 16 +- client/firewall/nftables/acl_linux.go | 549 +-------- client/firewall/nftables/manager_linux.go | 121 +- .../firewall/nftables/manager_linux_test.go | 78 +- client/firewall/nftables/route_linux.go | 431 ------- client/firewall/nftables/router_linux.go | 798 ++++++++++++ client/firewall/nftables/router_linux_test.go | 605 +++++++-- client/firewall/test/cases_linux.go | 20 +- client/firewall/uspfilter/uspfilter.go | 42 +- client/firewall/uspfilter/uspfilter_test.go | 20 +- client/internal/acl/id/id.go | 25 + client/internal/acl/manager.go | 255 ++-- client/internal/acl/manager_test.go | 170 +-- client/internal/engine.go | 9 +- client/internal/routemanager/dynamic/route.go | 2 +- client/internal/routemanager/manager.go | 6 +- .../routemanager/refcounter/refcounter.go | 197 ++- .../internal/routemanager/refcounter/types.go | 6 +- .../routemanager/server_nonandroid.go | 16 +- client/internal/routemanager/static/route.go | 2 +- .../routemanager/systemops/systemops.go | 2 +- .../systemops/systemops_generic.go | 4 +- management/proto/management.pb.go | 1087 +++++++++++------ management/proto/management.proto | 84 +- management/server/account.go | 4 +- management/server/account_test.go | 7 +- management/server/grpcserver.go | 4 + management/server/http/api/openapi.yml | 30 +- management/server/http/api/types.gen.go | 30 +- management/server/http/policies_handler.go | 33 +- management/server/http/routes_handler.go | 16 +- management/server/http/routes_handler_test.go | 48 +- management/server/mock_server/account_mock.go | 8 +- management/server/network.go | 13 +- management/server/peer_test.go | 7 +- management/server/policy.go | 48 +- management/server/route.go | 292 ++++- management/server/route_test.go | 536 ++++++-- route/route.go | 5 +- 48 files changed, 4683 insertions(+), 2444 deletions(-) create mode 100644 client/firewall/manager/firewall_test.go delete mode 100644 
client/firewall/nftables/route_linux.go create mode 100644 client/firewall/nftables/router_linux.go create mode 100644 client/internal/acl/id/id.go diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 8b7136841..2d743f790 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable, + ignore_words_list: erro,clienta,hastable,iif skip: go.mod,go.sum only_warn: 1 golangci: diff --git a/client/firewall/iface.go b/client/firewall/iface.go index 882daef75..d0b5209c0 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -1,6 +1,8 @@ package firewall -import "github.com/netbirdio/netbird/iface" +import ( + "github.com/netbirdio/netbird/iface" +) // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index b77cc8f43..c6a96a876 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -19,24 +19,22 @@ const ( // rules chains contains the effective ACL rules chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT" - - postRoutingMark = "0x000007e4" ) type aclManager struct { - iptablesClient *iptables.IPTables - wgIface iFaceMapper - routeingFwChainName string + iptablesClient *iptables.IPTables + wgIface iFaceMapper + routingFwChainName string entries map[string][][]string ipsetStore *ipsetStore } -func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) { +func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { m := &aclManager{ - iptablesClient: iptablesClient, - wgIface: wgIface, - routeingFwChainName: 
routeingFwChainName, + iptablesClient: iptablesClient, + wgIface: wgIface, + routingFwChainName: routingFwChainName, entries: make(map[string][][]string), ipsetStore: newIpsetStore(), @@ -61,7 +59,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, route return m, nil } -func (m *aclManager) AddFiltering( +func (m *aclManager) AddPeerFiltering( ip net.IP, protocol firewall.Protocol, sPort *firewall.Port, @@ -127,7 +125,7 @@ func (m *aclManager) AddFiltering( return nil, fmt.Errorf("rule already exists") } - if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil { + if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { return nil, err } @@ -139,28 +137,16 @@ func (m *aclManager) AddFiltering( chain: chain, } - if !shouldAddToPrerouting(protocol, dPort, direction) { - return []firewall.Rule{rule}, nil - } - - rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip) - if err != nil { - return []firewall.Rule{rule}, err - } - return []firewall.Rule{rule, rulePrerouting}, nil + return []firewall.Rule{rule}, nil } -// DeleteRule from the firewall by rule definition -func (m *aclManager) DeleteRule(rule firewall.Rule) error { +// DeletePeerRule from the firewall by rule definition +func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") } - if r.chain == "PREROUTING" { - goto DELETERULE - } - if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { // delete IP from ruleset IPs list and ipset if _, ok := ipsetList.ips[r.ip]; ok { @@ -185,14 +171,7 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error { } } -DELETERULE: - var table string - if r.chain == "PREROUTING" { - table = "mangle" - } else { - table = "filter" - } - err := m.iptablesClient.Delete(table, r.chain, r.specs...) + err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) 
if err != nil { log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err) } @@ -203,44 +182,6 @@ func (m *aclManager) Reset() error { return m.cleanChains() } -func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) { - var src []string - if ipsetName != "" { - src = []string{"-m", "set", "--set", ipsetName, "src"} - } else { - src = []string{"-s", ip.String()} - } - specs := []string{ - "-d", m.wgIface.Address().IP.String(), - "-p", protocol, - "--dport", port, - "-j", "MARK", "--set-mark", postRoutingMark, - } - - specs = append(src, specs...) - - ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...) - if err != nil { - return nil, fmt.Errorf("failed to check rule: %w", err) - } - if ok { - return nil, fmt.Errorf("rule already exists") - } - - if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil { - return nil, err - } - - rule := &Rule{ - ruleID: uuid.New().String(), - specs: specs, - ipsetName: ipsetName, - ip: ip.String(), - chain: "PREROUTING", - } - return rule, nil -} - // todo write less destructive cleanup mechanism func (m *aclManager) cleanChains() error { ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) @@ -291,25 +232,6 @@ func (m *aclManager) cleanChains() error { } } - ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") - if err != nil { - log.Debugf("failed to list chains: %s", err) - return err - } - if ok { - for _, rule := range m.entries["PREROUTING"] { - err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...) 
- if err != nil { - log.Errorf("failed to delete rule: %v, %s", rule, err) - } - } - err = m.iptablesClient.ClearChain("mangle", "PREROUTING") - if err != nil { - log.Debugf("failed to clear %s chain: %s", "PREROUTING", err) - return err - } - } - for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := ipset.Flush(ipsetName); err != nil { log.Errorf("flush ipset %q during reset: %v", ipsetName, err) @@ -338,17 +260,9 @@ func (m *aclManager) createDefaultChains() error { for chainName, rules := range m.entries { for _, rule := range rules { - if chainName == "FORWARD" { - // position 2 because we add it after router's, jump rule - if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil { - log.Debugf("failed to create input chain jump rule: %s", err) - return err - } - } else { - if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil { - log.Debugf("failed to create input chain jump rule: %s", err) - return err - } + if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { + log.Debugf("failed to create input chain jump rule: %s", err) + return err } } } @@ -356,40 +270,29 @@ func (m *aclManager) createDefaultChains() error { return nil } +// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed. +// We want to make sure our traffic is not dropped by existing rules. + +// The existing FORWARD rules/policies decide outbound traffic towards our interface. +// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. + +// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. 
func (m *aclManager) seedInitialEntries() { - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules}) + established := getConntrackEstablished() m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules}) + m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) + m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) + m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules}) + m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) + m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) - m.appendToEntries("FORWARD", - []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) - 
m.appendToEntries("FORWARD", - []string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) - m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName}) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName}) - - m.appendToEntries("PREROUTING", - []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark}) + m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) + m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) } func (m *aclManager) appendToEntries(chainName string, spec []string) { @@ -456,18 +359,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string { return ipsetName } } - -func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { - if proto == "all" { - return false - } - - if direction != firewall.RuleDirectionIN { - return false - } - - if dPort == nil { - return false - } - return true -} diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 2d231ec45..fae41d9c5 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "sync" "github.com/coreos/go-iptables/iptables" @@ -21,7 +22,7 @@ type Manager struct { ipv4Client *iptables.IPTables aclMgr *aclManager - router *routerManager + router *router } // iFaceMapper defines subset methods of interface required for manager @@ -43,12 +44,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouterManager(context, iptablesClient) + m.router, err = newRouter(context, iptablesClient, wgIface) if err != nil { log.Debugf("failed 
to initialize route related chains: %s", err) return nil, err } - m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName()) + m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) if err != nil { log.Debugf("failed to initialize ACL manager: %s", err) return nil, err @@ -57,10 +58,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, protocol firewall.Protocol, sPort *firewall.Port, @@ -73,33 +74,62 @@ func (m *Manager) AddFiltering( m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) + return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering( + sources [] netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.DeleteRule(rule) + if !destination.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + } + + return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.aclMgr.DeletePeerRule(rule) +} + +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteRouteRule(rule) } func (m *Manager) 
IsServerRouteSupported() bool { return true } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.InsertRoutingRules(pair) + return m.router.AddNatRule(pair) } -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveRoutingRules(pair) + return m.router.RemoveNatRule(pair) +} + +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + return firewall.SetLegacyManagement(m.router, isLegacy) } // Reset firewall to the default state @@ -125,7 +155,7 @@ func (m *Manager) AllowNetbird() error { return nil } - _, err := m.AddFiltering( + _, err := m.AddPeerFiltering( net.ParseIP("0.0.0.0"), "all", nil, @@ -138,7 +168,7 @@ func (m *Manager) AllowNetbird() error { if err != nil { return fmt.Errorf("failed to allow netbird interface traffic: %w", err) } - _, err = m.AddFiltering( + _, err = m.AddPeerFiltering( net.ParseIP("0.0.0.0"), "all", nil, @@ -153,3 +183,7 @@ func (m *Manager) AllowNetbird() error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } + +func getConntrackEstablished() []string { + return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} +} diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index ceb116c62..0072aa159 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -14,6 +14,21 @@ import ( "github.com/netbirdio/netbird/iface" ) +var ifaceMock = &iFaceMock{ + NameFunc: func() string { + return "lo" + }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("10.20.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("10.20.0.0"), + Mask: 
net.IPv4Mask(255, 255, 255, 0), + }, + } + }, +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string @@ -40,23 +55,8 @@ func TestIptablesManager(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) - mock := &iFaceMock{ - NameFunc: func() string { - return "lo" - }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ - IP: net.ParseIP("10.20.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("10.20.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - } - }, - } - // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) time.Sleep(time.Second) @@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) { t.Run("add first rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") for _, r := range rule1 { @@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) { port := &fw.Port{ Values: []int{8043: 8046}, } - rule2, err = manager.AddFiltering( + rule2, err = manager.AddPeerFiltering( ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") require.NoError(t, err, "failed to add rule") @@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) { t.Run("delete first rule", func(t *testing.T) { for _, r := range rule1 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) 
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) { t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") } @@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) { // add second rule ip := net.ParseIP("10.20.0.3") port := &fw.Port{Values: []int{5353}} - _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") + _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") err = manager.Reset() @@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("add first rule with set", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddFiltering( + rule1, err = manager.AddPeerFiltering( ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "default", "accept HTTP traffic", ) @@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) { port := &fw.Port{ Values: []int{443}, } - rule2, err = manager.AddFiltering( + rule2, err = manager.AddPeerFiltering( ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "default", "accept HTTPS traffic from ports range", ) @@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("delete first rule", func(t *testing.T) { for _, r := range rule1 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") @@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to 
delete rule") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") @@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index e8f09a106..737b20785 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -5,368 +5,478 @@ package iptables import ( "context" "fmt" + "net/netip" + "strconv" "strings" "github.com/coreos/go-iptables/iptables" + "github.com/hashicorp/go-multierror" + "github.com/nadoo/ipset" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) const ( - Ipv4Forwarding = "netbird-rt-forwarding" - ipv4Nat = "netbird-rt-nat" + ipv4Nat = "netbird-rt-nat" ) // constants needed to manage and create iptable rules const ( tableFilter = "filter" tableNat = "nat" - chainFORWARD = "FORWARD" chainPOSTROUTING = "POSTROUTING" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWD = "NETBIRD-RT-FWD" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" + + matchSet = "--match-set" ) -type routerManager struct { - 
ctx context.Context - stop context.CancelFunc - iptablesClient *iptables.IPTables - rules map[string][]string +type routeFilteringRuleParams struct { + Sources []netip.Prefix + Destination netip.Prefix + Proto firewall.Protocol + SPort *firewall.Port + DPort *firewall.Port + Direction firewall.RuleDirection + Action firewall.Action + SetName string } -func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) { +type router struct { + ctx context.Context + stop context.CancelFunc + iptablesClient *iptables.IPTables + rules map[string][]string + ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}] + wgIface iFaceMapper + legacyManagement bool +} + +func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { ctx, cancel := context.WithCancel(parentCtx) - m := &routerManager{ + r := &router{ ctx: ctx, stop: cancel, iptablesClient: iptablesClient, rules: make(map[string][]string), + wgIface: wgIface, } - err := m.cleanUpDefaultForwardRules() + r.ipsetCounter = refcounter.New( + r.createIpSet, + func(name string, _ struct{}) error { + return r.deleteIpSet(name) + }, + ) + + if err := ipset.Init(); err != nil { + return nil, fmt.Errorf("init ipset: %w", err) + } + + err := r.cleanUpDefaultForwardRules() if err != nil { - log.Errorf("failed to cleanup routing rules: %s", err) + log.Errorf("cleanup routing rules: %s", err) return nil, err } - err = m.createContainers() + err = r.createContainers() if err != nil { - log.Errorf("failed to create containers for route: %s", err) + log.Errorf("create containers for route: %s", err) } - return m, err + return r, err } -// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain -func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error { - err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, 
pair) - if err != nil { - return err +func (r *router) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) + if _, ok := r.rules[string(ruleKey)]; ok { + return ruleKey, nil } - err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair)) - if err != nil { - return err + var setName string + if len(sources) > 1 { + setName = firewall.GenerateSetName(sources) + if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { + return nil, fmt.Errorf("create or get ipset: %w", err) + } + } + + params := routeFilteringRuleParams{ + Sources: sources, + Destination: destination, + Proto: proto, + SPort: sPort, + DPort: dPort, + Action: action, + SetName: setName, + } + + rule := genRouteFilteringRuleSpec(params) + if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + return nil, fmt.Errorf("add route rule: %v", err) + } + + r.rules[string(ruleKey)] = rule + + return ruleKey, nil +} + +func (r *router) DeleteRouteRule(rule firewall.Rule) error { + ruleKey := rule.GetRuleID() + + if rule, exists := r.rules[ruleKey]; exists { + setName := r.findSetNameInRule(rule) + + if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("delete route rule: %v", err) + } + delete(r.rules, ruleKey) + + if setName != "" { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + return fmt.Errorf("failed to remove ipset: %w", err) + } + } + } else { + log.Debugf("route rule %s not found", ruleKey) + } + + return nil +} + +func (r *router) findSetNameInRule(rule []string) string { + for i, arg := range rule { + if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { + return 
rule[i+3] + } + } + return "" +} + +func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) { + if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { + return struct{}{}, fmt.Errorf("create set %s: %w", setName, err) + } + + for _, prefix := range sources { + if err := ipset.AddPrefix(setName, prefix); err != nil { + return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err) + } + } + + return struct{}{}, nil +} + +func (r *router) deleteIpSet(setName string) error { + if err := ipset.Destroy(setName); err != nil { + return fmt.Errorf("destroy set %s: %w", setName, err) + } + return nil +} + +// AddNatRule inserts an iptables rule pair into the nat chain +func (r *router) AddNatRule(pair firewall.RouterPair) error { + if r.legacyManagement { + log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) + if err := r.addLegacyRouteRule(pair); err != nil { + return fmt.Errorf("add legacy routing rule: %w", err) + } } if !pair.Masquerade { return nil } - err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) - if err != nil { - return err + if err := r.addNatRule(pair); err != nil { + return fmt.Errorf("add nat rule: %w", err) } - err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) - if err != nil { - return err + if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("add inverse nat rule: %w", err) } return nil } -// insertRoutingRule inserts an iptables rule -func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { - var err error +// RemoveNatRule removes an iptables rule pair from forwarding and nat chains +func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat 
rule: %w", err) + } - ruleKey := firewall.GenKey(keyFormat, pair.ID) - rule := genRuleSpec(jump, pair.Source, pair.Destination) - existingRule, found := i.rules[ruleKey] - if found { - err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } + + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + + return nil +} + +// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls +func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if err := r.removeLegacyRouteRule(pair); err != nil { + return err + } + + rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} + if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + r.rules[ruleKey] = rule + + return nil +} + +func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } - delete(i.rules, ruleKey) - } - - err = i.iptablesClient.Insert(table, chain, 1, rule...) 
- if err != nil { - return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) - } - - i.rules[ruleKey] = rule - - return nil -} - -// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains -func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error { - err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair) - if err != nil { - return err - } - - err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair)) - if err != nil { - return err - } - - if !pair.Masquerade { - return nil - } - - err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair) - if err != nil { - return err - } - - err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair)) - if err != nil { - return err + delete(r.rules, ruleKey) + } else { + log.Debugf("legacy forwarding rule %s not found", ruleKey) } return nil } -func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error { - var err error +// GetLegacyManagement returns the current legacy management mode +func (r *router) GetLegacyManagement() bool { + return r.legacyManagement +} - ruleKey := firewall.GenKey(keyFormat, pair.ID) - existingRule, found := i.rules[ruleKey] - if found { - err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) 
- if err != nil { - return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) +// SetLegacyManagement sets the route manager to use legacy management mode +func (r *router) SetLegacyManagement(isLegacy bool) { + r.legacyManagement = isLegacy +} + +// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls +func (r *router) RemoveAllLegacyRouteRules() error { + var merr *multierror.Error + for k, rule := range r.rules { + if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + continue + } + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) } } - delete(i.rules, ruleKey) - - return nil + return nberrors.FormatErrorOrNil(merr) } -func (i *routerManager) RouteingFwChainName() string { - return chainRTFWD -} - -func (i *routerManager) Reset() error { - err := i.cleanUpDefaultForwardRules() - if err != nil { - return err +func (r *router) Reset() error { + var merr *multierror.Error + if err := r.cleanUpDefaultForwardRules(); err != nil { + merr = multierror.Append(merr, err) } - i.rules = make(map[string][]string) - return nil + r.rules = make(map[string][]string) + + if err := r.ipsetCounter.Flush(); err != nil { + merr = multierror.Append(merr, err) + } + + return nberrors.FormatErrorOrNil(merr) } -func (i *routerManager) cleanUpDefaultForwardRules() error { - err := i.cleanJumpRules() +func (r *router) cleanUpDefaultForwardRules() error { + err := r.cleanJumpRules() if err != nil { return err } log.Debug("flushing routing related tables") - ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD) - if err != nil { - log.Errorf("failed check chain %s,error: %v", chainRTFWD, err) - return err - } else if ok { - err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD) + for _, chain := range []string{chainRTFWD, chainRTNAT} 
{ + table := tableFilter + if chain == chainRTNAT { + table = tableNat + } + + ok, err := r.iptablesClient.ChainExists(table, chain) if err != nil { - log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err) + log.Errorf("failed check chain %s, error: %v", chain, err) return err + } else if ok { + err = r.iptablesClient.ClearAndDeleteChain(table, chain) + if err != nil { + log.Errorf("failed cleaning chain %s, error: %v", chain, err) + return err + } } } - ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT) - if err != nil { - log.Errorf("failed check chain %s,error: %v", chainRTNAT, err) - return err - } else if ok { - err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT) - if err != nil { - log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err) - return err - } - } - return nil -} - -func (i *routerManager) createContainers() error { - if i.rules[Ipv4Forwarding] != nil { - return nil - } - - errMSGFormat := "failed creating chain %s,error: %v" - err := i.createChain(tableFilter, chainRTFWD) - if err != nil { - return fmt.Errorf(errMSGFormat, chainRTFWD, err) - } - - err = i.createChain(tableNat, chainRTNAT) - if err != nil { - return fmt.Errorf(errMSGFormat, chainRTNAT, err) - } - - err = i.addJumpRules() - if err != nil { - return fmt.Errorf("error while creating jump rules: %v", err) - } - return nil } -// addJumpRules create jump rules to send packets to NetBird chains -func (i *routerManager) addJumpRules() error { - rule := []string{"-j", chainRTFWD} - err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...) 
+func (r *router) createContainers() error { + for _, chain := range []string{chainRTFWD, chainRTNAT} { + if err := r.createAndSetupChain(chain); err != nil { + return fmt.Errorf("create chain %s: %v", chain, err) + } + } + + if err := r.insertEstablishedRule(chainRTFWD); err != nil { + return fmt.Errorf("insert established rule: %v", err) + } + + return r.addJumpRules() +} + +func (r *router) createAndSetupChain(chain string) error { + table := r.getTableForChain(chain) + + if err := r.iptablesClient.NewChain(table, chain); err != nil { + return fmt.Errorf("failed creating chain %s, error: %v", chain, err) + } + + return nil +} + +func (r *router) getTableForChain(chain string) string { + if chain == chainRTNAT { + return tableNat + } + return tableFilter +} + +func (r *router) insertEstablishedRule(chain string) error { + establishedRule := getConntrackEstablished() + + err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...) + if err != nil { + return fmt.Errorf("failed to insert established rule: %v", err) + } + + ruleKey := "established-" + chain + r.rules[ruleKey] = establishedRule + + return nil +} + +func (r *router) addJumpRules() error { + rule := []string{"-j", chainRTNAT} + err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) if err != nil { return err } - i.rules[Ipv4Forwarding] = rule - - rule = []string{"-j", chainRTNAT} - err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) - if err != nil { - return err - } - i.rules[ipv4Nat] = rule + r.rules[ipv4Nat] = rule return nil } -// cleanJumpRules cleans jump rules that was sending packets to NetBird chains -func (i *routerManager) cleanJumpRules() error { - var err error - errMSGFormat := "failed cleaning rule from chain %s,err: %v" - rule, found := i.rules[Ipv4Forwarding] +func (r *router) cleanJumpRules() error { + rule, found := r.rules[ipv4Nat] if found { - err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...) 
+ err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) if err != nil { - return fmt.Errorf(errMSGFormat, chainFORWARD, err) - } - } - rule, found = i.rules[ipv4Nat] - if found { - err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) - if err != nil { - return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err) + return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) } } - rules, err := i.iptablesClient.List("nat", "POSTROUTING") - if err != nil { - return fmt.Errorf("failed to list rules: %s", err) - } - - for _, ruleString := range rules { - if !strings.Contains(ruleString, "NETBIRD") { - continue - } - rule := strings.Fields(ruleString) - err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...) - if err != nil { - return fmt.Errorf("failed to delete postrouting jump rule: %s", err) - } - } - - rules, err = i.iptablesClient.List(tableFilter, "FORWARD") - if err != nil { - return fmt.Errorf("failed to list rules in FORWARD chain: %s", err) - } - - for _, ruleString := range rules { - if !strings.Contains(ruleString, "NETBIRD") { - continue - } - rule := strings.Fields(ruleString) - err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...) 
- if err != nil { - return fmt.Errorf("failed to delete FORWARD jump rule: %s", err) - } - } return nil } -func (i *routerManager) createChain(table, newChain string) error { - chains, err := i.iptablesClient.ListChains(table) - if err != nil { - return fmt.Errorf("couldn't get %s table chains, error: %v", table, err) - } +func (r *router) addNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) - shouldCreateChain := true - for _, chain := range chains { - if chain == newChain { - shouldCreateChain = false - } - } - - if shouldCreateChain { - err = i.iptablesClient.NewChain(table, newChain) - if err != nil { - return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) - } - - // Add the loopback return rule to the NAT chain - loopbackRule := []string{"-o", "lo", "-j", "RETURN"} - err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...) - if err != nil { - return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err) - } - - err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") - if err != nil { - return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) - } - - } - return nil -} - -// addNATRule appends an iptables rule pair to the nat chain -func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(keyFormat, pair.ID) - rule := genRuleSpec(jump, pair.Source, pair.Destination) - existingRule, found := i.rules[ruleKey] - if found { - err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...) 
- if err != nil { + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) } - delete(i.rules, ruleKey) + delete(r.rules, ruleKey) } - // inserting after loopback ignore rule - err := i.iptablesClient.Insert(table, chain, 2, rule...) - if err != nil { + rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) + if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) } - i.rules[ruleKey] = rule + r.rules[ruleKey] = rule return nil } -// genRuleSpec generates rule specification -func genRuleSpec(jump, source, destination string) []string { - return []string{"-s", source, "-d", destination, "-j", jump} +func (r *router) removeNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { + return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err) + } + + delete(r.rules, ruleKey) + } else { + log.Debugf("nat rule %s not found", ruleKey) + } + + return nil } -func getIptablesRuleType(table string) string { - ruleType := "forwarding" - if table == tableNat { - ruleType = "nat" +func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { + intdir := "-i" + if inverse { + intdir = "-o" } - return ruleType + return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump} +} + +func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { + var rule []string + + if params.SetName != "" { + rule = append(rule, "-m", "set", matchSet, params.SetName, "src") + } 
else if len(params.Sources) > 0 { + source := params.Sources[0] + rule = append(rule, "-s", source.String()) + } + + rule = append(rule, "-d", params.Destination.String()) + + if params.Proto != firewall.ProtocolALL { + rule = append(rule, "-p", strings.ToLower(string(params.Proto))) + rule = append(rule, applyPort("--sport", params.SPort)...) + rule = append(rule, applyPort("--dport", params.DPort)...) + } + + rule = append(rule, "-j", actionToStr(params.Action)) + + return rule +} + +func applyPort(flag string, port *firewall.Port) []string { + if port == nil { + return nil + } + + if port.IsRange && len(port.Values) == 2 { + return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])} + } + + if len(port.Values) > 1 { + portList := make([]string, len(port.Values)) + for i, p := range port.Values { + portList[i] = strconv.Itoa(p) + } + return []string{"-m", "multiport", flag, strings.Join(portList, ",")} + } + + return []string{flag, strconv.Itoa(port.Values[0])} } diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 79b970c36..6cede09e2 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -4,11 +4,13 @@ package iptables import ( "context" + "net/netip" "os/exec" "testing" "github.com/coreos/go-iptables/iptables" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -28,7 +30,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "should return a valid iptables manager") defer func() { @@ -37,26 +39,22 @@ 
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.Len(t, manager.rules, 2, "should have created rules map") - exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD) - require.True(t, exists, "forwarding rule should exist") - - exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) + exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.True(t, exists, "postrouting rule should exist") pair := firewall.RouterPair{ ID: "abc", - Source: "100.100.100.1/32", - Destination: "100.100.100.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.100.0/24"), Masquerade: true, } - forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination) + forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) require.NoError(t, err, "inserting rule should not return error") - nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination) + nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false) err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) 
require.NoError(t, err, "inserting rule should not return error") @@ -65,7 +63,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.NoError(t, err, "shouldn't return error") } -func TestIptablesManager_InsertRoutingRules(t *testing.T) { +func TestIptablesManager_AddNatRule(t *testing.T) { if !isIptablesSupported() { t.SkipNow() @@ -76,7 +74,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") defer func() { @@ -86,35 +84,13 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { } }() - err = manager.InsertRoutingRules(testCase.InputPair) + err = manager.AddNatRule(testCase.InputPair) require.NoError(t, err, "forwarding pair should be inserted") - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) - exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) 
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.True(t, exists, "forwarding rule should exist") - - foundRule, found := manager.rules[forwardRuleKey] - require.True(t, found, "forwarding rule should exist in the manager map") - require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") - - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) - - exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.True(t, exists, "income forwarding rule should exist") - - foundRule, found = manager.rules[inForwardRuleKey] - require.True(t, found, "income forwarding rule should exist in the manager map") - require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) 
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) if testCase.InputPair.Masquerade { require.True(t, exists, "nat rule should be created") @@ -127,8 +103,8 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { require.False(t, foundNat, "nat rule should not exist in the map") } - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) @@ -146,7 +122,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { } } -func TestIptablesManager_RemoveRoutingRules(t *testing.T) { +func TestIptablesManager_RemoveNatRule(t *testing.T) { if !isIptablesSupported() { t.SkipNow() @@ -156,7 +132,7 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") defer func() { _ = manager.Reset() @@ -164,26 +140,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { require.NoError(t, err, "shouldn't return error") - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, 
testCase.InputPair.Destination) - - err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...) - require.NoError(t, err, "inserting rule should not return error") - - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) - - err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...) - require.NoError(t, err, "inserting rule should not return error") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) require.NoError(t, err, "inserting rule should not return error") - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) 
require.NoError(t, err, "inserting rule should not return error") @@ -191,28 +155,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { err = manager.Reset() require.NoError(t, err, "shouldn't return error") - err = manager.RemoveRoutingRules(testCase.InputPair) + err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") - exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.False(t, exists, "forwarding rule should not exist") - - _, found := manager.rules[forwardRuleKey] - require.False(t, found, "forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.False(t, exists, "income forwarding rule should not exist") - - _, found = manager.rules[inForwardRuleKey] - require.False(t, found, "income forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.False(t, exists, "nat rule should not exist") - _, found = manager.rules[natRuleKey] + _, found := manager.rules[natRuleKey] require.False(t, found, "nat rule should exist in the manager map") exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) 
@@ -221,7 +171,175 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { _, found = manager.rules[inNatRuleKey] require.False(t, found, "income nat rule should exist in the manager map") - + }) + } +} + +func TestRouter_AddRouteFiltering(t *testing.T) { + if !isIptablesSupported() { + t.Skip("iptables not supported on this system") + } + + iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + require.NoError(t, err, "Failed to create iptables client") + + r, err := newRouter(context.Background(), iptablesClient, ifaceMock) + require.NoError(t, err, "Failed to create router manager") + + defer func() { + err := r.Reset() + require.NoError(t, err, "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + sPort *firewall.Port + dPort *firewall.Port + direction firewall.RuleDirection + action firewall.Action + expectSet bool + }{ + { + name: "Basic TCP rule with single source", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolTCP, + sPort: nil, + dPort: &firewall.Port{Values: []int{80}}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with multiple sources", + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolUDP, + sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionDrop, + expectSet: true, + }, + { + name: "All protocols rule", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + destination: netip.MustParsePrefix("0.0.0.0/0"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: 
firewall.ActionAccept, + expectSet: false, + }, + { + name: "ICMP rule", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolICMP, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "TCP rule with multiple source ports", + sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, + destination: netip.MustParsePrefix("192.168.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with single IP and port range", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolUDP, + sPort: nil, + dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + { + name: "TCP rule with source and destination ports", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + destination: netip.MustParsePrefix("172.16.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &firewall.Port{Values: []int{22}}, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "Drop all incoming traffic", + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + destination: netip.MustParsePrefix("192.168.0.0/24"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, 
tt.dPort, tt.action) + require.NoError(t, err, "AddRouteFiltering failed") + + // Check if the rule is in the internal map + rule, ok := r.rules[ruleKey.GetRuleID()] + assert.True(t, ok, "Rule not found in internal map") + + // Log the internal rule + t.Logf("Internal rule: %v", rule) + + // Check if the rule exists in iptables + exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...) + assert.NoError(t, err, "Failed to check rule existence") + assert.True(t, exists, "Rule not found in iptables") + + // Verify rule content + params := routeFilteringRuleParams{ + Sources: tt.sources, + Destination: tt.destination, + Proto: tt.proto, + SPort: tt.sPort, + DPort: tt.dPort, + Action: tt.action, + SetName: "", + } + + expectedRule := genRouteFilteringRuleSpec(params) + + if tt.expectSet { + setName := firewall.GenerateSetName(tt.sources) + params.SetName = setName + expectedRule = genRouteFilteringRuleSpec(params) + + // Check if the set was created + _, exists := r.ipsetCounter.Get(setName) + assert.True(t, exists, "IPSet not created") + } + + assert.Equal(t, expectedRule, rule, "Rule content mismatch") + + // Clean up + err = r.DeleteRouteRule(ruleKey) + require.NoError(t, err, "Failed to delete rule") }) } } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 6e4edb63e..a6185d370 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,15 +1,21 @@ package manager import ( + "crypto/sha256" + "encoding/hex" "fmt" "net" + "net/netip" + "sort" + "strings" + + log "github.com/sirupsen/logrus" ) const ( - NatFormat = "netbird-nat-%s" - ForwardingFormat = "netbird-fwd-%s" - InNatFormat = "netbird-nat-in-%s" - InForwardingFormat = "netbird-fwd-in-%s" + ForwardingFormatPrefix = "netbird-fwd-" + ForwardingFormat = "netbird-fwd-%s-%t" + NatFormat = "netbird-nat-%s-%t" ) // Rule abstraction should be implemented by each firewall manager @@ -49,11 +55,11 @@ type Manager interface { // 
AllowNetbird allows netbird interface traffic AllowNetbird() error - // AddFiltering rule to the firewall + // AddPeerFiltering adds a rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule - AddFiltering( + AddPeerFiltering( ip net.IP, proto Protocol, sPort *Port, @@ -64,17 +70,25 @@ type Manager interface { comment string, ) ([]Rule, error) - // DeleteRule from the firewall by rule definition - DeleteRule(rule Rule) error + // DeletePeerRule from the firewall by rule definition + DeletePeerRule(rule Rule) error // IsServerRouteSupported returns true if the firewall supports server side routing operations IsServerRouteSupported() bool - // InsertRoutingRules inserts a routing firewall rule - InsertRoutingRules(pair RouterPair) error + AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) - // RemoveRoutingRules removes a routing firewall rule - RemoveRoutingRules(pair RouterPair) error + // DeleteRouteRule deletes a routing rule + DeleteRouteRule(rule Rule) error + + // AddNatRule inserts a routing NAT rule + AddNatRule(pair RouterPair) error + + // RemoveNatRule removes a routing NAT rule + RemoveNatRule(pair RouterPair) error + + // SetLegacyManagement sets the legacy management mode + SetLegacyManagement(legacy bool) error // Reset firewall to the default state Reset() error @@ -83,6 +97,89 @@ type Manager interface { Flush() error } -func GenKey(format string, input string) string { - return fmt.Sprintf(format, input) +func GenKey(format string, pair RouterPair) string { + return fmt.Sprintf(format, pair.ID, pair.Inverse) +} + +// LegacyManager defines the interface for legacy management operations +type LegacyManager interface { + RemoveAllLegacyRouteRules() error + GetLegacyManagement() bool + SetLegacyManagement(bool) +} + +// SetLegacyManagement sets the route manager to use legacy management +func 
SetLegacyManagement(router LegacyManager, isLegacy bool) error { + oldLegacy := router.GetLegacyManagement() + + if oldLegacy != isLegacy { + router.SetLegacyManagement(isLegacy) + log.Debugf("Set legacy management to %v", isLegacy) + } + + // client reconnected to a newer mgmt, we need to clean up the legacy rules + if !isLegacy && oldLegacy { + if err := router.RemoveAllLegacyRouteRules(); err != nil { + return fmt.Errorf("remove legacy routing rules: %v", err) + } + + log.Debugf("Legacy routing rules removed") + } + + return nil +} + +// GenerateSetName generates a unique name for an ipset based on the given sources. +func GenerateSetName(sources []netip.Prefix) string { + // sort for consistent naming + sortPrefixes(sources) + + var sourcesStr strings.Builder + for _, src := range sources { + sourcesStr.WriteString(src.String()) + } + + hash := sha256.Sum256([]byte(sourcesStr.String())) + shortHash := hex.EncodeToString(hash[:])[:8] + + return fmt.Sprintf("nb-%s", shortHash) +} + +// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix +func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { + if len(prefixes) == 0 { + return prefixes + } + + merged := []netip.Prefix{prefixes[0]} + for _, prefix := range prefixes[1:] { + last := merged[len(merged)-1] + if last.Contains(prefix.Addr()) { + // If the current prefix is contained within the last merged prefix, skip it + continue + } + if prefix.Contains(last.Addr()) { + // If the current prefix contains the last merged prefix, replace it + merged[len(merged)-1] = prefix + } else { + // Otherwise, add the current prefix to the merged list + merged = append(merged, prefix) + } + } + + return merged +} + +// sortPrefixes sorts the given slice of netip.Prefix in place. +// It sorts first by IP address, then by prefix length (most specific to least specific). 
+func sortPrefixes(prefixes []netip.Prefix) { + sort.Slice(prefixes, func(i, j int) bool { + addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) + if addrCmp != 0 { + return addrCmp < 0 + } + + // If IP addresses are the same, compare prefix lengths (longer prefixes first) + return prefixes[i].Bits() > prefixes[j].Bits() + }) } diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go new file mode 100644 index 000000000..3f47d6679 --- /dev/null +++ b/client/firewall/manager/firewall_test.go @@ -0,0 +1,192 @@ +package manager_test + +import ( + "net/netip" + "reflect" + "regexp" + "testing" + + "github.com/netbirdio/netbird/client/firewall/manager" +) + +func TestGenerateSetName(t *testing.T) { + t.Run("Different orders result in same hash", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := manager.GenerateSetName(prefixes1) + result2 := manager.GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) + } + }) + + t.Run("Result format is correct", func(t *testing.T) { + prefixes := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + + result := manager.GenerateSetName(prefixes) + + matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) + if err != nil { + t.Fatalf("Error matching regex: %v", err) + } + if !matched { + t.Errorf("Result format is incorrect: %s", result) + } + }) + + t.Run("Empty input produces consistent result", func(t *testing.T) { + result1 := manager.GenerateSetName([]netip.Prefix{}) + result2 := manager.GenerateSetName([]netip.Prefix{}) + + if result1 != result2 { + t.Errorf("Empty input produced inconsistent results: %s != %s", result1, 
result2) + } + }) + + t.Run("IPv4 and IPv6 mixing", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("2001:db8::/32"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := manager.GenerateSetName(prefixes1) + result2 := manager.GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) + } + }) +} + +func TestMergeIPRanges(t *testing.T) { + tests := []struct { + name string + input []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Empty input", + input: []netip.Prefix{}, + expected: []netip.Prefix{}, + }, + { + name: "Single range", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Two non-overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + }, + { + name: "One range containing another", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "One range containing another (different order)", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Overlapping ranges 
(different order)", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.128/25"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Multiple overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Partially overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + }, + { + name: "IPv6 ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::/48"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := manager.MergeIPRanges(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go index b63a9f104..8c94b7dd4 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/manager/routerpair.go @@ -1,18 +1,26 @@ package manager +import ( + "net/netip" + + "github.com/netbirdio/netbird/route" +) + type RouterPair struct { - ID string - Source string - Destination string + ID route.ID + Source netip.Prefix + Destination netip.Prefix Masquerade bool + Inverse bool } -func GetInPair(pair RouterPair) RouterPair { +func GetInversePair(pair RouterPair) RouterPair { return RouterPair{ ID: 
pair.ID, // invert Source/Destination Source: pair.Destination, Destination: pair.Source, Masquerade: pair.Masquerade, + Inverse: true, } } diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 1fa41b63a..85cba9e1c 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -33,9 +33,10 @@ const ( allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) +const flushError = "flush: %w" + var ( - anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00} + anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} ) type AclManager struct { @@ -48,7 +49,6 @@ type AclManager struct { chainInputRules *nftables.Chain chainOutputRules *nftables.Chain chainFwFilter *nftables.Chain - chainPrerouting *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -64,7 +64,7 @@ type iFaceMapper interface { func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) - // and is permanent. Using same connection for booth type of operations + // and is permanent. 
Using same connection for both type of operations // overloads netlink with high amount of rules ( > 10000) sConn, err := nftables.New(nftables.AsLasting()) if err != nil { @@ -90,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *AclManager) AddFiltering( +func (m *AclManager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -120,20 +120,11 @@ func (m *AclManager) AddFiltering( } newRules = append(newRules, ioRule) - if !shouldAddToPrerouting(proto, dPort, direction) { - return newRules, nil - } - - preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip) - if err != nil { - return newRules, err - } - newRules = append(newRules, preroutingRule) return newRules, nil } -// DeleteRule from the firewall by rule definition -func (m *AclManager) DeleteRule(rule firewall.Rule) error { +// DeletePeerRule from the firewall by rule definition +func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") @@ -199,8 +190,7 @@ func (m *AclManager) DeleteRule(rule firewall.Rule) error { return nil } -// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for -// input and output chains +// createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ &expr.Payload{ @@ -214,13 +204,13 @@ func (m *AclManager) createDefaultAllowRules() error { SourceRegister: 1, DestRegister: 1, Len: 4, - Mask: []byte{0x00, 0x00, 0x00, 0x00}, - Xor: zeroXor, + Mask: []byte{0, 0, 0, 0}, + Xor: []byte{0, 0, 0, 0}, }, // net address &expr.Cmp{ Register: 1, - Data: 
[]byte{0x00, 0x00, 0x00, 0x00}, + Data: []byte{0, 0, 0, 0}, }, &expr.Verdict{ Kind: expr.VerdictAccept, @@ -246,13 +236,13 @@ func (m *AclManager) createDefaultAllowRules() error { SourceRegister: 1, DestRegister: 1, Len: 4, - Mask: []byte{0x00, 0x00, 0x00, 0x00}, - Xor: zeroXor, + Mask: []byte{0, 0, 0, 0}, + Xor: []byte{0, 0, 0, 0}, }, // net address &expr.Cmp{ Register: 1, - Data: []byte{0x00, 0x00, 0x00, 0x00}, + Data: []byte{0, 0, 0, 0}, }, &expr.Verdict{ Kind: expr.VerdictAccept, @@ -266,10 +256,8 @@ func (m *AclManager) createDefaultAllowRules() error { Exprs: expOut, }) - err := m.rConn.Flush() - if err != nil { - log.Debugf("failed to create default allow rules: %s", err) - return err + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) } return nil } @@ -290,15 +278,11 @@ func (m *AclManager) Flush() error { log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) } - if err := m.refreshRuleHandles(m.chainPrerouting); err != nil { - log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err) - } - return nil } func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { - ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset) + ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ r.nftRule, @@ -308,18 +292,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f }, nil } - ifaceKey := expr.MetaKeyIIFNAME - if direction == firewall.RuleDirectionOUT { - ifaceKey = expr.MetaKeyOIFNAME - } - expressions := []expr.Any{ - &expr.Meta{Key: ifaceKey, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - } + var expressions []expr.Any if proto != firewall.ProtocolALL { expressions = 
append(expressions, &expr.Payload{ @@ -329,21 +302,15 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f Len: uint32(1), }) - var protoData []byte - switch proto { - case firewall.ProtocolTCP: - protoData = []byte{unix.IPPROTO_TCP} - case firewall.ProtocolUDP: - protoData = []byte{unix.IPPROTO_UDP} - case firewall.ProtocolICMP: - protoData = []byte{unix.IPPROTO_ICMP} - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) + protoData, err := protoToInt(proto) + if err != nil { + return nil, fmt.Errorf("convert protocol to number: %v", err) } + expressions = append(expressions, &expr.Cmp{ Register: 1, Op: expr.CmpOpEq, - Data: protoData, + Data: []byte{protoData}, }) } @@ -432,10 +399,9 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f } else { chain = m.chainOutputRules } - nftRule := m.rConn.InsertRule(&nftables.Rule{ + nftRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: chain, - Position: 0, Exprs: expressions, UserData: userData, }) @@ -453,139 +419,13 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f return rule, nil } -func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) { - var protoData []byte - switch proto { - case firewall.ProtocolTCP: - protoData = []byte{unix.IPPROTO_TCP} - case firewall.ProtocolUDP: - protoData = []byte{unix.IPPROTO_UDP} - case firewall.ProtocolICMP: - protoData = []byte{unix.IPPROTO_ICMP} - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) - } - - ruleId := generateRuleIdForMangle(ipset, ip, proto, port) - if r, ok := m.rules[ruleId]; ok { - return &Rule{ - r.nftRule, - r.nftSet, - r.ruleID, - ip, - }, nil - } - - var ipExpression expr.Any - // add individual IP for match if no ipset defined - rawIP := ip.To4() - if ipset == nil { - ipExpression = &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: 
rawIP, - } - } else { - ipExpression = &expr.Lookup{ - SourceRegister: 1, - SetName: ipset.Name, - SetID: ipset.ID, - } - } - - expressions := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - ipExpression, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: m.wgIface.Address().IP.To4(), - }, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), - Len: uint32(1), - }, - &expr.Cmp{ - Register: 1, - Op: expr.CmpOpEq, - Data: protoData, - }, - } - - if port != nil { - expressions = append(expressions, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseTransportHeader, - Offset: 2, - Len: 2, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: encodePort(*port), - }, - ) - } - - expressions = append(expressions, - &expr.Immediate{ - Register: 1, - Data: postroutingMark, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - SourceRegister: true, - Register: 1, - }, - ) - - nftRule := m.rConn.InsertRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainPrerouting, - Position: 0, - Exprs: expressions, - UserData: []byte(ruleId), - }) - - if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf("flush insert rule: %v", err) - } - - rule := &Rule{ - nftRule: nftRule, - nftSet: ipset, - ruleID: ruleId, - ip: ip, - } - - m.rules[ruleId] = rule - if ipset != nil { - m.ipsetStore.AddReferenceToIpset(ipset.Name) - } - return rule, nil -} - func (m *AclManager) createDefaultChains() (err error) { // chainNameInputRules chain := m.createChain(chainNameInputRules) err = m.rConn.Flush() if err != nil { log.Debugf("failed to create chain (%s): %s", chain.Name, err) - return err + return fmt.Errorf(flushError, err) } m.chainInputRules = chain @@ -601,9 +441,6 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-input-filter // 
type filter hook input priority filter; policy accept; chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) - //netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept - m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME) - m.addFwdAllow(chain, expr.MetaKeyIIFNAME) m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules m.addDropExpressions(chain, expr.MetaKeyIIFNAME) err = m.rConn.Flush() @@ -615,7 +452,6 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-output-filter // type filter hook output priority filter; policy accept; chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) - m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME) m.addFwdAllow(chain, expr.MetaKeyOIFNAME) m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules m.addDropExpressions(chain, expr.MetaKeyOIFNAME) @@ -627,24 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-forward-filter m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) - m.addJumpRulesToRtForward() // to - m.addMarkAccept() - m.addJumpRuleToInputChain() // to netbird-acl-input-rules + m.addJumpRulesToRtForward() // to netbird-rt-fwd m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + err = m.rConn.Flush() if err != nil { log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err) - return err + return fmt.Errorf(flushError, err) } - // netbird-acl-output-filter - // type filter hook output priority filter; policy accept; - m.chainPrerouting = m.createPreroutingMangle() - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err) - return err - } return nil } @@ -667,59 +494,6 @@ func (m *AclManager) addJumpRulesToRtForward() { Chain: m.chainFwFilter, Exprs: expressions, }) - - 
expressions = []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictJump, - Chain: m.routeingFwChainName, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) -} - -func (m *AclManager) addMarkAccept() { - // oifname "wt0" meta mark 0x000007e4 accept - // iifname "wt0" meta mark 0x000007e4 accept - ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME} - for _, iface := range ifaces { - expressions := []expr.Any{ - &expr.Meta{Key: iface, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: postroutingMark, - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) - } } func (m *AclManager) createChain(name string) *nftables.Chain { @@ -729,6 +503,9 @@ func (m *AclManager) createChain(name string) *nftables.Chain { } chain = m.rConn.AddChain(chain) + + insertReturnTrafficRule(m.rConn, m.workTable, chain) + return chain } @@ -746,74 +523,6 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha return m.rConn.AddChain(chain) } -func (m *AclManager) createPreroutingMangle() *nftables.Chain { - polAccept := nftables.ChainPolicyAccept - chain := &nftables.Chain{ - Name: "netbird-acl-prerouting-filter", - Table: m.workTable, - Hooknum: nftables.ChainHookPrerouting, - Priority: nftables.ChainPriorityMangle, - Type: nftables.ChainTypeFilter, - Policy: &polAccept, - } - - chain = m.rConn.AddChain(chain) - - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - expressions := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - 
&expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: m.wgIface.Address().IP.To4(), - }, - &expr.Immediate{ - Register: 1, - Data: postroutingMark, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - SourceRegister: true, - Register: 1, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: chain, - Exprs: expressions, - }) - return chain -} - func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any { expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, @@ -832,101 +541,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met return nil } -func (m *AclManager) addJumpRuleToInputChain() { - expressions := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictJump, - Chain: m.chainInputRules.Name, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) -} - -func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - var srcOp, dstOp expr.CmpOp - if netIfName == expr.MetaKeyIIFNAME { - srcOp = expr.CmpOpEq - dstOp = expr.CmpOpNeq - } else { - srcOp = expr.CmpOpNeq - dstOp = expr.CmpOpEq - } - expressions := []expr.Any{ - &expr.Meta{Key: netIfName, 
Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: srcOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: dstOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: expressions, - }) -} - func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - var srcOp, dstOp expr.CmpOp - if iifname == expr.MetaKeyIIFNAME { - srcOp = expr.CmpOpNeq - dstOp = expr.CmpOpEq - } else { - srcOp = expr.CmpOpEq - dstOp = expr.CmpOpNeq - } + dstOp := expr.CmpOpNeq expressions := []expr.Any{ &expr.Meta{Key: iifname, Register: 1}, &expr.Cmp{ @@ -934,24 +551,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: srcOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, &expr.Payload{ DestRegister: 2, Base: expr.PayloadBaseNetworkHeader, @@ -982,7 +581,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { } 
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, &expr.Cmp{ @@ -990,47 +588,12 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, &expr.Verdict{ Kind: expr.VerdictJump, Chain: to, }, } + _ = m.rConn.AddRule(&nftables.Rule{ Table: chain.Table, Chain: chain, @@ -1132,7 +695,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { return nil } -func generateRuleId( +func generatePeerRuleId( ip net.IP, sPort *firewall.Port, dPort *firewall.Port, @@ -1155,33 +718,6 @@ func generateRuleId( } return "set:" + ipset.Name + rulesetID } -func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string { - // case of icmp port is empty - var p string - if port != nil { - p = port.String() - } - if ipset != nil { - return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p) - } else { - return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p) - } -} - -func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { - if proto == "all" { - return 
false - } - - if direction != firewall.RuleDirectionIN { - return false - } - - if dPort == nil && proto != firewall.ProtocolICMP { - return false - } - return true -} func encodePort(port firewall.Port) []byte { bs := make([]byte, 2) @@ -1191,6 +727,19 @@ func encodePort(port firewall.Port) []byte { func ifname(n string) []byte { b := make([]byte, 16) - copy(b, []byte(n+"\x00")) + copy(b, n+"\x00") return b } + +func protoToInt(protocol firewall.Protocol) (uint8, error) { + switch protocol { + case firewall.ProtocolTCP: + return unix.IPPROTO_TCP, nil + case firewall.ProtocolUDP: + return unix.IPPROTO_UDP, nil + case firewall.ProtocolICMP: + return unix.IPPROTO_ICMP, nil + } + + return 0, fmt.Errorf("unsupported protocol: %s", protocol) +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a376c98c3..d2258ae08 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -5,9 +5,11 @@ import ( "context" "fmt" "net" + "net/netip" "sync" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" @@ -15,8 +17,11 @@ import ( ) const ( - // tableName is the name of the table that is used for filtering by the Netbird client - tableName = "netbird" + // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + tableNameNetbird = "netbird" + + tableNameFilter = "filter" + chainNameInput = "INPUT" ) // Manager of iptables firewall @@ -41,12 +46,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return nil, err } - m.router, err = newRouter(context, workTable) + m.router, err = newRouter(context, workTable, wgIface) if err != nil { return nil, err } - m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName()) + m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) if err != nil { return 
nil, err } @@ -54,11 +59,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -76,33 +81,52 @@ func (m *Manager) AddFiltering( return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) } - return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) + return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclManager.DeleteRule(rule) + if !destination.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + } + + return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.aclManager.DeletePeerRule(rule) +} + +// DeleteRouteRule deletes a routing rule +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteRouteRule(rule) } func (m *Manager) IsServerRouteSupported() bool { return true } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer 
m.mutex.Unlock() - return m.router.AddRoutingRules(pair) + return m.router.AddNatRule(pair) } -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveRoutingRules(pair) + return m.router.RemoveNatRule(pair) } // AllowNetbird allows netbird interface traffic @@ -126,7 +150,7 @@ func (m *Manager) AllowNetbird() error { var chain *nftables.Chain for _, c := range chains { - if c.Table.Name == "filter" && c.Name == "INPUT" { + if c.Table.Name == tableNameFilter && c.Name == chainNameForward { chain = c break } @@ -157,6 +181,27 @@ func (m *Manager) AllowNetbird() error { return nil } +// SetLegacyManagement sets the route manager to use legacy management +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + oldLegacy := m.router.legacyManagement + + if oldLegacy != isLegacy { + m.router.legacyManagement = isLegacy + log.Debugf("Set legacy management to %v", isLegacy) + } + + // client reconnected to a newer mgmt, we need to cleanup the legacy rules + if !isLegacy && oldLegacy { + if err := m.router.RemoveAllLegacyRouteRules(); err != nil { + return fmt.Errorf("remove legacy routing rules: %v", err) + } + + log.Debugf("Legacy routing rules removed") + } + + return nil +} + // Reset firewall to the default state func (m *Manager) Reset() error { m.mutex.Lock() @@ -185,14 +230,16 @@ func (m *Manager) Reset() error { } } - m.router.ResetForwardRules() + if err := m.router.Reset(); err != nil { + return fmt.Errorf("reset forward rules: %v", err) + } tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list of tables: %w", err) } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } @@ -218,12 +265,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { } for _, t := range tables { - if t.Name == tableName { + if t.Name == 
tableNameNetbird { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } @@ -239,9 +286,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, + &expr.Verdict{}, }, UserData: []byte(allowNetbirdInputRuleID), } @@ -251,7 +296,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { ifName := ifname(m.wgIface.Name()) for _, rule := range existedRules { - if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" { + if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { if len(rule.Exprs) < 4 { if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { continue @@ -265,3 +310,33 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable } return nil } + +func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + } + + conn.InsertRule(rule) +} diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 1f226e315..7f78a9a2e 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ 
b/client/firewall/nftables/manager_linux_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/stretchr/testify/require" "golang.org/x/sys/unix" @@ -17,6 +18,21 @@ import ( "github.com/netbirdio/netbird/iface" ) +var ifaceMock = &iFaceMock{ + NameFunc: func() string { + return "lo" + }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("100.96.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("100.96.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string @@ -40,23 +56,9 @@ func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { - mock := &iFaceMock{ - NameFunc: func() string { - return "lo" - }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ - IP: net.ParseIP("100.96.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("100.96.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - } - }, - } // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) time.Sleep(time.Second * 3) @@ -70,7 +72,7 @@ func TestNftablesManager(t *testing.T) { testClient := &nftables.Conn{} - rule, err := manager.AddFiltering( + rule, err := manager.AddPeerFiltering( ip, fw.ProtocolTCP, nil, @@ -88,17 +90,34 @@ func TestNftablesManager(t *testing.T) { rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - require.Len(t, rules, 1, "expected 1 rules") + require.Len(t, rules, 2, "expected 2 rules") + + expectedExprs1 := []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + 
DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions") ipToAdd, _ := netip.AddrFromSlice(ip) add := ipToAdd.Unmap() - expectedExprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname("lo"), - }, + expectedExprs2 := []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, @@ -134,10 +153,10 @@ func TestNftablesManager(t *testing.T) { }, &expr.Verdict{Kind: expr.VerdictDrop}, } - require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions") + require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions") for _, r := range rule { - err = manager.DeleteRule(r) + err = manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") } @@ -146,7 +165,8 @@ func TestNftablesManager(t *testing.T) { rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - require.Len(t, rules, 0, "expected 0 rules after deletion") + // established rule remains + require.Len(t, rules, 1, "expected 1 rules after deletion") err = manager.Reset() require.NoError(t, err, "failed to reset") @@ -187,9 +207,9 @@ func TestNFtablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = 
manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/firewall/nftables/route_linux.go b/client/firewall/nftables/route_linux.go deleted file mode 100644 index 71d5ac88e..000000000 --- a/client/firewall/nftables/route_linux.go +++ /dev/null @@ -1,431 +0,0 @@ -package nftables - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "net/netip" - - "github.com/google/nftables" - "github.com/google/nftables/binaryutil" - "github.com/google/nftables/expr" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/firewall/manager" -) - -const ( - chainNameRouteingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-nat" - - userDataAcceptForwardRuleSrc = "frwacceptsrc" - userDataAcceptForwardRuleDst = "frwacceptdst" - - loopbackInterface = "lo\x00" -) - -// some presets for building nftable rules -var ( - zeroXor = binaryutil.NativeEndian.PutUint32(0) - - exprCounterAccept = []expr.Any{ - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") -) - -type router struct { - ctx context.Context - stop context.CancelFunc - conn *nftables.Conn - workTable *nftables.Table - filterTable *nftables.Table - chains map[string]*nftables.Chain - // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules - rules map[string]*nftables.Rule - isDefaultFwdRulesEnabled bool -} - -func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) { - ctx, cancel := context.WithCancel(parentCtx) - - r := &router{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - workTable: workTable, - chains: make(map[string]*nftables.Chain), - rules: 
make(map[string]*nftables.Rule), - } - - var err error - r.filterTable, err = r.loadFilterTable() - if err != nil { - if errors.Is(err, errFilterTableNotFound) { - log.Warnf("table 'filter' not found for forward rules") - } else { - return nil, err - } - } - - err = r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - err = r.createContainers() - if err != nil { - log.Errorf("failed to create containers for route: %s", err) - } - return r, err -} - -func (r *router) RouteingFwChainName() string { - return chainNameRouteingFw -} - -// ResetForwardRules cleans existing nftables default forward rules from the system -func (r *router) ResetForwardRules() { - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to reset forward rules: %s", err) - } -} - -func (r *router) loadFilterTable() (*nftables.Table, error) { - tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) - if err != nil { - return nil, fmt.Errorf("nftables: unable to list tables: %v", err) - } - - for _, table := range tables { - if table.Name == "filter" { - return table, nil - } - } - - return nil, errFilterTableNotFound -} - -func (r *router) createContainers() error { - - r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{ - Name: chainNameRouteingFw, - Table: r.workTable, - }) - - r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ - Name: chainNameRoutingNat, - Table: r.workTable, - Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, - Type: nftables.ChainTypeNAT, - }) - - // Add RETURN rule for loopback interface - loRule := &nftables.Rule{ - Table: r.workTable, - Chain: r.chains[chainNameRoutingNat], - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: []byte(loopbackInterface), - }, - &expr.Verdict{Kind: expr.VerdictReturn}, - }, - } - 
r.conn.InsertRule(loRule) - - err := r.refreshRulesMap() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to initialize table: %v", err) - } - return nil -} - -// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain -func (r *router) AddRoutingRules(pair manager.RouterPair) error { - err := r.refreshRulesMap() - if err != nil { - return err - } - - err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) - if err != nil { - return err - } - err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) - if err != nil { - return err - } - - if pair.Masquerade { - err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) - if err != nil { - return err - } - err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) - if err != nil { - return err - } - } - - if r.filterTable != nil && !r.isDefaultFwdRulesEnabled { - log.Debugf("add default accept forward rule") - r.acceptForwardRule(pair.Source) - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err) - } - return nil -} - -// addRoutingRule inserts a nftable rule to the conn client flush queue -func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) - - var expression []expr.Any - if isNat { - expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic - } else { - expression = append(sourceExp, append(destExp, exprCounterAccept...)...) 
// nolint:gocritic - } - - ruleKey := manager.GenKey(format, pair.ID) - - _, exists := r.rules[ruleKey] - if exists { - err := r.removeRoutingRule(format, pair) - if err != nil { - return err - } - } - - r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ - Table: r.workTable, - Chain: r.chains[chainName], - Exprs: expression, - UserData: []byte(ruleKey), - }) - return nil -} - -func (r *router) acceptForwardRule(sourceNetwork string) { - src := generateCIDRMatcherExpressions(true, sourceNetwork) - dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0") - - var exprs []expr.Any - exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic - Kind: expr.VerdictAccept, - })...) - - rule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleSrc), - } - - r.conn.AddRule(rule) - - src = generateCIDRMatcherExpressions(true, "0.0.0.0/0") - dst = generateCIDRMatcherExpressions(false, sourceNetwork) - - exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic - Kind: expr.VerdictAccept, - })...) 
- - rule = &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleDst), - } - r.conn.AddRule(rule) - r.isDefaultFwdRulesEnabled = true -} - -// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains -func (r *router) RemoveRoutingRules(pair manager.RouterPair) error { - err := r.refreshRulesMap() - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.ForwardingFormat, pair) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair)) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.NatFormat, pair) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair)) - if err != nil { - return err - } - - if len(r.rules) == 0 { - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) - } - log.Debugf("nftables: removed rules for %s", pair.Destination) - return nil -} - -// removeRoutingRule add a nftable rule to the removal queue and delete from rules map -func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error { - ruleKey := manager.GenKey(format, pair.ID) - - rule, found := r.rules[ruleKey] - if found { - ruleType := "forwarding" - if rule.Chain.Type == nftables.ChainTypeNAT { - ruleType = "nat" - } - - err := r.conn.DelRule(rule) - if err != nil { - return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err) - } - - log.Debugf("nftables: removing %s rule for %s", ruleType, 
pair.Destination) - - delete(r.rules, ruleKey) - } - return nil -} - -// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid -// duplicates and to get missing attributes that we don't have when adding new rules -func (r *router) refreshRulesMap() error { - for _, chain := range r.chains { - rules, err := r.conn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("nftables: unable to list rules: %v", err) - } - for _, rule := range rules { - if len(rule.UserData) > 0 { - r.rules[string(rule.UserData)] = rule - } - } - } - return nil -} - -func (r *router) cleanUpDefaultForwardRules() error { - if r.filterTable == nil { - r.isDefaultFwdRulesEnabled = false - return nil - } - - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return err - } - - var rules []*nftables.Rule - for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name { - continue - } - if chain.Name != "FORWARD" { - continue - } - - rules, err = r.conn.GetRules(r.filterTable, chain) - if err != nil { - return err - } - } - - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) { - err := r.conn.DelRule(rule) - if err != nil { - return err - } - } - } - r.isDefaultFwdRulesEnabled = false - return r.conn.Flush() -} - -// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR -func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any { - ip, network, _ := net.ParseCIDR(cidr) - ipToAdd, _ := netip.AddrFromSlice(ip) - add := ipToAdd.Unmap() - - var offSet uint32 - if source { - offSet = 12 // src offset - } else { - offSet = 16 // dst offset - } - - return []expr.Any{ - // fetch src add - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: offSet, - Len: 4, - }, - // net mask - &expr.Bitwise{ - DestRegister: 1, - 
SourceRegister: 1, - Len: 4, - Mask: network.Mask, - Xor: zeroXor, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: add.AsSlice(), - }, - } -} diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go new file mode 100644 index 000000000..aa61e1858 --- /dev/null +++ b/client/firewall/nftables/router_linux.go @@ -0,0 +1,798 @@ +package nftables + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "net" + "net/netip" + "strings" + + "github.com/google/nftables" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" +) + +const ( + chainNameRoutingFw = "netbird-rt-fwd" + chainNameRoutingNat = "netbird-rt-nat" + chainNameForward = "FORWARD" + + userDataAcceptForwardRuleIif = "frwacceptiif" + userDataAcceptForwardRuleOif = "frwacceptoif" +) + +const refreshRulesMapError = "refresh rules map: %w" + +var ( + errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") +) + +type router struct { + ctx context.Context + stop context.CancelFunc + conn *nftables.Conn + workTable *nftables.Table + filterTable *nftables.Table + chains map[string]*nftables.Chain + // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules + rules map[string]*nftables.Rule + ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] + + wgIface iFaceMapper + legacyManagement bool +} + +func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { + ctx, cancel := context.WithCancel(parentCtx) + + r := &router{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, 
+ workTable: workTable, + chains: make(map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + wgIface: wgIface, + } + + r.ipsetCounter = refcounter.New( + r.createIpSet, + r.deleteIpSet, + ) + + var err error + r.filterTable, err = r.loadFilterTable() + if err != nil { + if errors.Is(err, errFilterTableNotFound) { + log.Warnf("table 'filter' not found for forward rules") + } else { + return nil, err + } + } + + err = r.cleanUpDefaultForwardRules() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + err = r.createContainers() + if err != nil { + log.Errorf("failed to create containers for route: %s", err) + } + return r, err +} + +// Reset cleans existing nftables default forward rules from the system +func (r *router) Reset() error { + // clear without deleting the ipsets, the nf table will be deleted by the caller + r.ipsetCounter.Clear() + + return r.cleanUpDefaultForwardRules() +} + +func (r *router) cleanUpDefaultForwardRules() error { + if r.filterTable == nil { + return nil + } + + chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %v", err) + } + + for _, chain := range chains { + if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + continue + } + + rules, err := r.conn.GetRules(r.filterTable, chain) + if err != nil { + return fmt.Errorf("get rules: %v", err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule: %v", err) + } + } + } + } + + return r.conn.Flush() +} + +func (r *router) loadFilterTable() (*nftables.Table, error) { + tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + if err != nil { + return nil, fmt.Errorf("nftables: unable to list tables: %v", err) + } + + for _, 
table := range tables { + if table.Name == "filter" { + return table, nil + } + } + + return nil, errFilterTableNotFound +} + +func (r *router) createContainers() error { + + r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingFw, + Table: r.workTable, + }) + + insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingNat, + Table: r.workTable, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + + r.acceptForwardRules() + + err := r.refreshRulesMap() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + err = r.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: unable to initialize table: %v", err) + } + return nil +} + +// AddRouteFiltering appends a nftables rule to the routing chain +func (r *router) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) + if _, ok := r.rules[string(ruleKey)]; ok { + return ruleKey, nil + } + + chain := r.chains[chainNameRoutingFw] + var exprs []expr.Any + + switch { + case len(sources) == 1 && sources[0].Bits() == 0: + // If it's 0.0.0.0/0, we don't need to add any source matching + case len(sources) == 1: + // If there's only one source, we can use it directly + exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) 
+ default: + // If there are multiple sources, create or get an ipset + var err error + exprs, err = r.getIpSetExprs(sources, exprs) + if err != nil { + return nil, fmt.Errorf("get ipset expressions: %w", err) + } + } + + // Handle destination + exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) + + // Handle protocol + if proto != firewall.ProtocolALL { + protoNum, err := protoToInt(proto) + if err != nil { + return nil, fmt.Errorf("convert protocol to number: %w", err) + } + exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}) + exprs = append(exprs, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{protoNum}, + }) + + exprs = append(exprs, applyPort(sPort, true)...) + exprs = append(exprs, applyPort(dPort, false)...) + } + + exprs = append(exprs, &expr.Counter{}) + + var verdict expr.VerdictKind + if action == firewall.ActionAccept { + verdict = expr.VerdictAccept + } else { + verdict = expr.VerdictDrop + } + exprs = append(exprs, &expr.Verdict{Kind: verdict}) + + rule := &nftables.Rule{ + Table: r.workTable, + Chain: chain, + Exprs: exprs, + UserData: []byte(ruleKey), + } + + r.rules[string(ruleKey)] = r.conn.AddRule(rule) + + return ruleKey, r.conn.Flush() +} + +func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { + setName := firewall.GenerateSetName(sources) + ref, err := r.ipsetCounter.Increment(setName, sources) + if err != nil { + return nil, fmt.Errorf("create or get ipset for sources: %w", err) + } + + exprs = append(exprs, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: ref.Out.Name, + SetID: ref.Out.ID, + }, + ) + return exprs, nil +} + +func (r *router) DeleteRouteRule(rule firewall.Rule) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleKey := rule.GetRuleID() + nftRule, exists := 
r.rules[ruleKey] + if !exists { + log.Debugf("route rule %s not found", ruleKey) + return nil + } + + setName := r.findSetNameInRule(nftRule) + + if err := r.deleteNftRule(nftRule, ruleKey); err != nil { + return fmt.Errorf("delete: %w", err) + } + + if setName != "" { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + return fmt.Errorf("decrement ipset reference: %w", err) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { + // overlapping prefixes will result in an error, so we need to merge them + sources = firewall.MergeIPRanges(sources) + + set := &nftables.Set{ + Name: setName, + Table: r.workTable, + // required for prefixes + Interval: true, + KeyType: nftables.TypeIPAddr, + } + + var elements []nftables.SetElement + for _, prefix := range sources { + // TODO: Implement IPv6 support + if prefix.Addr().Is6() { + log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + continue + } + + // nftables needs half-open intervals [firstIP, lastIP) for prefixes + // e.g. 
10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc + firstIP := prefix.Addr() + lastIP := calculateLastIP(prefix).Next() + + elements = append(elements, + // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 + // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, + nftables.SetElement{Key: firstIP.AsSlice()}, + nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, + ) + } + + if err := r.conn.AddSet(set, elements); err != nil { + return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) + } + + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf("flush error: %w", err) + } + + log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) + + return set, nil +} + +// calculateLastIP determines the last IP in a given prefix. +func calculateLastIP(prefix netip.Prefix) netip.Addr { + hostMask := ^uint32(0) >> prefix.Masked().Bits() + lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask + + return netip.AddrFrom4(uint32ToBytes(lastIP)) +} + +// Utility function to convert netip.Addr to uint32. +func uint32FromNetipAddr(addr netip.Addr) uint32 { + b := addr.As4() + return binary.BigEndian.Uint32(b[:]) +} + +// Utility function to convert uint32 to a netip-compatible byte slice. 
+func uint32ToBytes(ip uint32) [4]byte { + var b [4]byte + binary.BigEndian.PutUint32(b[:], ip) + return b +} + +func (r *router) deleteIpSet(setName string, set *nftables.Set) error { + r.conn.DelSet(set) + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + log.Debugf("Deleted unused ipset %s", setName) + return nil +} + +func (r *router) findSetNameInRule(rule *nftables.Rule) string { + for _, e := range rule.Exprs { + if lookup, ok := e.(*expr.Lookup); ok { + return lookup.SetName + } + } + return "" +} + +func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule %s: %w", ruleKey, err) + } + delete(r.rules, ruleKey) + + log.Debugf("removed route rule %s", ruleKey) + + return nil +} + +// AddNatRule appends a nftables rule pair to the nat chain +func (r *router) AddNatRule(pair firewall.RouterPair) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + if r.legacyManagement { + log.Warnf("This peer is connected to a NetBird Management service with an older version. 
Allowing all traffic for %s", pair.Destination) + if err := r.addLegacyRouteRule(pair); err != nil { + return fmt.Errorf("add legacy routing rule: %w", err) + } + } + + if pair.Masquerade { + if err := r.addNatRule(pair); err != nil { + return fmt.Errorf("add nat rule: %w", err) + } + + if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("add inverse nat rule: %w", err) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) + } + + return nil +} + +// addNatRule inserts a nftables rule to the conn client flush queue +func (r *router) addNatRule(pair firewall.RouterPair) error { + sourceExp := generateCIDRMatcherExpressions(true, pair.Source) + destExp := generateCIDRMatcherExpressions(false, pair.Destination) + + dir := expr.MetaKeyIIFNAME + if pair.Inverse { + dir = expr.MetaKeyOIFNAME + } + + intf := ifname(r.wgIface.Name()) + exprs := []expr.Any{ + &expr.Meta{ + Key: dir, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + } + + exprs = append(exprs, sourceExp...) + exprs = append(exprs, destExp...) 
+ exprs = append(exprs, + &expr.Counter{}, &expr.Masq{}, + ) + + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if _, exists := r.rules[ruleKey]; exists { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove routing rule: %w", err) + } + } + + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: exprs, + UserData: []byte(ruleKey), + }) + return nil +} + +// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls +func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { + sourceExp := generateCIDRMatcherExpressions(true, pair.Source) + destExp := generateCIDRMatcherExpressions(false, pair.Destination) + + exprs := []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic + + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if _, exists := r.rules[ruleKey]; exists { + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + } + + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingFw], + Exprs: expression, + UserData: []byte(ruleKey), + }) + return nil +} + +// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls +func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) + + delete(r.rules, ruleKey) + } else { + log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) + } + + 
return nil +} + +// GetLegacyManagement returns the route manager's legacy management mode +func (r *router) GetLegacyManagement() bool { + return r.legacyManagement +} + +// SetLegacyManagement sets the route manager to use legacy management mode +func (r *router) SetLegacyManagement(isLegacy bool) { + r.legacyManagement = isLegacy +} + +// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls +func (r *router) RemoveAllLegacyRouteRules() error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + var merr *multierror.Error + for k, rule := range r.rules { + if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + continue + } + if err := r.conn.DelRule(rule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) + } + } + return nberrors.FormatErrorOrNil(merr) +} + +// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure +// that our traffic is not dropped by existing rules there. +// The existing FORWARD rules/policies decide outbound traffic towards our interface. +// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. 
+func (r *router) acceptForwardRules() { + if r.filterTable == nil { + log.Debugf("table 'filter' not found for forward rules, skipping accept rules") + return + } + + intf := ifname(r.wgIface.Name()) + + // Rule for incoming interface (iif) with counter + iifRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptForwardRuleIif), + } + r.conn.InsertRule(iifRule) + + // Rule for outgoing interface (oif) with counter + oifRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 2, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 2, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptForwardRuleOif), + } + + r.conn.InsertRule(oifRule) +} + +// RemoveNatRule removes a nftables rule pair from nat chains +func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + if err := r.removeNatRule(pair); err 
!= nil { + return fmt.Errorf("remove nat rule: %w", err) + } + + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } + + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) + } + + log.Debugf("nftables: removed rules for %s", pair.Destination) + return nil +} + +// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map +func (r *router) removeNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + err := r.conn.DelRule(rule) + if err != nil { + return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination) + + delete(r.rules, ruleKey) + } else { + log.Debugf("nftables: nat rule %s not found", ruleKey) + } + + return nil +} + +// refreshRulesMap refreshes the rule map with the latest rules. 
this is useful to avoid +// duplicates and to get missing attributes that we don't have when adding new rules +func (r *router) refreshRulesMap() error { + for _, chain := range r.chains { + rules, err := r.conn.GetRules(chain.Table, chain) + if err != nil { + return fmt.Errorf("nftables: unable to list rules: %v", err) + } + for _, rule := range rules { + if len(rule.UserData) > 0 { + r.rules[string(rule.UserData)] = rule + } + } + } + return nil +} + +// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR +func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { + var offset uint32 + if source { + offset = 12 // src offset + } else { + offset = 16 // dst offset + } + + ones := prefix.Bits() + // 0.0.0.0/0 doesn't need extra expressions + if ones == 0 { + return nil + } + + mask := net.CIDRMask(ones, 32) + + return []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offset, + Len: 4, + }, + // netmask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 4, + Mask: mask, + Xor: []byte{0, 0, 0, 0}, + }, + // net address + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: prefix.Masked().Addr().AsSlice(), + }, + } +} + +func applyPort(port *firewall.Port, isSource bool) []expr.Any { + if port == nil { + return nil + } + + var exprs []expr.Any + + offset := uint32(2) // Default offset for destination port + if isSource { + offset = 0 // Offset for source port + } + + exprs = append(exprs, &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: offset, + Len: 2, + }) + + if port.IsRange && len(port.Values) == 2 { + // Handle port range + exprs = append(exprs, + &expr.Cmp{ + Op: expr.CmpOpGte, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])), + }, + &expr.Cmp{ + Op: expr.CmpOpLte, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])), + }, + ) + } else { + // Handle single 
port or multiple ports + for i, p := range port.Values { + if i > 0 { + // Add a bitwise OR operation between port checks + exprs = append(exprs, &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: []byte{0x00, 0x00, 0xff, 0xff}, + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }) + } + exprs = append(exprs, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(p)), + }) + } + } + + return exprs +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 913fbd5d2..bbf92f3be 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -4,11 +4,15 @@ package nftables import ( "context" + "encoding/binary" + "net/netip" + "os/exec" "testing" "github.com/coreos/go-iptables/iptables" "github.com/google/nftables" "github.com/google/nftables/expr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -24,56 +28,50 @@ const ( NFTABLES ) -func TestNftablesManager_InsertRoutingRules(t *testing.T) { +func TestNftablesManager_AddNatRule(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this OS") } table, err := createWorkTable() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Failed to create work table") defer deleteWorkTable() for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table) + manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") nftablesTestingClient := &nftables.Conn{} - defer manager.ResetForwardRules() + defer func(manager *router) { + require.NoError(t, manager.Reset(), "failed to reset rules") + }(manager) require.NoError(t, err, "shouldn't return error") - err = manager.AddRoutingRules(testCase.InputPair) - defer func() { - _ = 
manager.RemoveRoutingRules(testCase.InputPair) - }() - require.NoError(t, err, "forwarding pair should be inserted") + err = manager.AddNatRule(testCase.InputPair) + require.NoError(t, err, "pair should be inserted") - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - testingExpression := append(sourceExp, destExp...) //nolint:gocritic - fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - - found := 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match") - found = 1 - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") + defer func(manager *router, pair firewall.RouterPair) { + require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") + }(manager, testCase.InputPair) if testCase.InputPair.Masquerade { - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + testingExpression := append(sourceExp, destExp...) 
//nolint:gocritic + testingExpression = append(testingExpression, + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(ifaceMock.Name()), + }, + ) + + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) found := 0 for _, chain := range manager.chains { rules, err := nftablesTestingClient.GetRules(chain.Table, chain) @@ -88,27 +86,20 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { require.Equal(t, 1, found, "should find at least 1 rule to test") } - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) - testingExpression = append(sourceExp, destExp...) //nolint:gocritic - inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - - found = 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match") - found = 1 - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") - if testCase.InputPair.Masquerade { - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + testingExpression := append(sourceExp, destExp...) 
//nolint:gocritic + testingExpression = append(testingExpression, + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(ifaceMock.Name()), + }, + ) + + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) found := 0 for _, chain := range manager.chains { rules, err := nftablesTestingClient.GetRules(chain.Table, chain) @@ -122,45 +113,37 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { } require.Equal(t, 1, found, "should find at least 1 rule to test") } + }) } } -func TestNftablesManager_RemoveRoutingRules(t *testing.T) { +func TestNftablesManager_RemoveNatRule(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this OS") } table, err := createWorkTable() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Failed to create work table") defer deleteWorkTable() for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table) + manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") nftablesTestingClient := &nftables.Conn{} - defer manager.ResetForwardRules() + defer func(manager *router) { + require.NoError(t, manager.Reset(), "failed to reset rules") + }(manager) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRouteingFw], - Exprs: forwardExp, - UserData: []byte(forwardRuleKey), - }) - natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) 
//nolint:gocritic - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ Table: manager.workTable, @@ -169,20 +152,11 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { UserData: []byte(natRuleKey), }) - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) - - forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRouteingFw], - Exprs: forwardExp, - UserData: []byte(inForwardRuleKey), - }) + sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) + destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination) natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) 
//nolint:gocritic - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ Table: manager.workTable, @@ -194,9 +168,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { err = nftablesTestingClient.Flush() require.NoError(t, err, "shouldn't return error") - manager.ResetForwardRules() + err = manager.Reset() + require.NoError(t, err, "shouldn't return error") - err = manager.RemoveRoutingRules(testCase.InputPair) + err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") for _, chain := range manager.chains { @@ -204,9 +179,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) for _, rule := range rules { if len(rule.UserData) > 0 { - require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist") require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") - require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist") require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") } } @@ -215,6 +188,468 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { } } +func TestRouter_AddRouteFiltering(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTable() + require.NoError(t, err, "Failed to create work table") + + defer deleteWorkTable() + + r, err := newRouter(context.Background(), workTable, ifaceMock) + require.NoError(t, err, "Failed to create router") + + defer func(r *router) { + require.NoError(t, r.Reset(), "Failed to reset rules") + }(r) + + tests := []struct { + name string + sources 
[]netip.Prefix + destination netip.Prefix + proto firewall.Protocol + sPort *firewall.Port + dPort *firewall.Port + direction firewall.RuleDirection + action firewall.Action + expectSet bool + }{ + { + name: "Basic TCP rule with single source", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolTCP, + sPort: nil, + dPort: &firewall.Port{Values: []int{80}}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with multiple sources", + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolUDP, + sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionDrop, + expectSet: true, + }, + { + name: "All protocols rule", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + destination: netip.MustParsePrefix("0.0.0.0/0"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "ICMP rule", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolICMP, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "TCP rule with multiple source ports", + sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, + destination: netip.MustParsePrefix("192.168.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with single IP and port 
range", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolUDP, + sPort: nil, + dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + { + name: "TCP rule with source and destination ports", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + destination: netip.MustParsePrefix("172.16.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &firewall.Port{Values: []int{22}}, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "Drop all incoming traffic", + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + destination: netip.MustParsePrefix("192.168.0.0/24"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + require.NoError(t, err, "AddRouteFiltering failed") + + // Check if the rule is in the internal map + rule, ok := r.rules[ruleKey.GetRuleID()] + assert.True(t, ok, "Rule not found in internal map") + + t.Log("Internal rule expressions:") + for i, expr := range rule.Exprs { + t.Logf(" [%d] %T: %+v", i, expr, expr) + } + + // Verify internal rule content + verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) + + // Check if the rule exists in nftables and verify its content + rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw]) + require.NoError(t, err, "Failed to get rules from nftables") + + var nftRule *nftables.Rule + for _, rule := range rules { + if 
string(rule.UserData) == ruleKey.GetRuleID() {
					nftRule = rule
					break
				}
			}

			require.NotNil(t, nftRule, "Rule not found in nftables")
			t.Log("Actual nftables rule expressions:")
			for i, expr := range nftRule.Exprs {
				t.Logf("  [%d] %T: %+v", i, expr, expr)
			}

			// Verify actual nftables rule content
			verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)

			// Clean up
			err = r.DeleteRouteRule(ruleKey)
			require.NoError(t, err, "Failed to delete rule")
		})
	}
}

// TestNftablesCreateIpSet exercises router.createIpSet: for each case it creates a
// named interval set from the source prefixes, verifies the set's properties and
// elements against nftables, and checks that overlapping prefixes are merged.
func TestNftablesCreateIpSet(t *testing.T) {
	if check() != NFTABLES {
		t.Skip("nftables not supported on this system")
	}

	workTable, err := createWorkTable()
	require.NoError(t, err, "Failed to create work table")

	defer deleteWorkTable()

	r, err := newRouter(context.Background(), workTable, ifaceMock)
	require.NoError(t, err, "Failed to create router")

	defer func() {
		require.NoError(t, r.Reset(), "Failed to reset router")
	}()

	tests := []struct {
		name string
		// sources are the prefixes handed to createIpSet.
		sources []netip.Prefix
		// expected, when non-empty, is the merged prefix list the set should
		// contain (overlaps collapsed); empty means sources are kept as-is.
		expected []netip.Prefix
	}{
		{
			name:    "Single IP",
			sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
		},
		{
			name: "Multiple IPs",
			sources: []netip.Prefix{
				netip.MustParsePrefix("192.168.1.1/32"),
				netip.MustParsePrefix("10.0.0.1/32"),
				netip.MustParsePrefix("172.16.0.1/32"),
			},
		},
		{
			name:    "Single Subnet",
			sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
		},
		{
			name: "Multiple Subnets with Various Prefix Lengths",
			sources: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/8"),
				netip.MustParsePrefix("172.16.0.0/16"),
				netip.MustParsePrefix("192.168.1.0/24"),
				netip.MustParsePrefix("203.0.113.0/26"),
			},
		},
		{
			name: "Mix of Single IPs and Subnets in Different Positions",
			sources: []netip.Prefix{
				netip.MustParsePrefix("192.168.1.1/32"),
				netip.MustParsePrefix("10.0.0.0/16"),
				netip.MustParsePrefix("172.16.0.1/32"),
				netip.MustParsePrefix("203.0.113.0/24"),
			},
		},
		{
			name: "Overlapping IPs/Subnets",
			sources: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/8"),
				netip.MustParsePrefix("10.0.0.0/16"),
				netip.MustParsePrefix("10.0.0.1/32"),
				netip.MustParsePrefix("192.168.0.0/16"),
				netip.MustParsePrefix("192.168.1.0/24"),
				netip.MustParsePrefix("192.168.1.1/32"),
			},
			expected: []netip.Prefix{
				netip.MustParsePrefix("10.0.0.0/8"),
				netip.MustParsePrefix("192.168.0.0/16"),
			},
		},
	}

	// printNftSets dumps the kernel's current nft sets; used only for debugging
	// failed set operations below.
	printNftSets := func() {
		cmd := exec.Command("nft", "list", "sets")
		output, err := cmd.CombinedOutput()
		if err != nil {
			t.Logf("Failed to run 'nft list sets': %v", err)
		} else {
			t.Logf("Current nft sets:\n%s", output)
		}
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			setName := firewall.GenerateSetName(tt.sources)
			set, err := r.createIpSet(setName, tt.sources)
			if err != nil {
				t.Logf("Failed to create IP set: %v", err)
				printNftSets()
				require.NoError(t, err, "Failed to create IP set")
			}
			require.NotNil(t, set, "Created set is nil")

			// Verify set properties
			assert.Equal(t, setName, set.Name, "Set name mismatch")
			assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
			assert.True(t, set.Interval, "Set interval property should be true")
			assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")

			// Fetch the created set from nftables
			fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
			require.NoError(t, err, "Failed to fetch created set")
			require.NotNil(t, fetchedSet, "Fetched set is nil")

			// Verify set elements
			elements, err := r.conn.GetSetElements(fetchedSet)
			require.NoError(t, err, "Failed to get set elements")

			// Count the number of unique prefixes (excluding interval end markers)
			uniquePrefixes := make(map[string]bool)
			for _, elem := range elements {
				if !elem.IntervalEnd {
					ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
					uniquePrefixes[ip.String()] = true
				}
			}

			// Check against expected merged prefixes
			expectedCount := len(tt.expected)
			if expectedCount == 0 {
				expectedCount = len(tt.sources)
			}
			assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")

			// Verify each expected prefix is in the set
			for _, expected := range tt.expected {
				found := false
				for _, elem := range elements {
					if !elem.IntervalEnd {
						ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
						if expected.Contains(ip) {
							found = true
							break
						}
					}
				}
				assert.True(t, found, "Expected prefix %s not found in set", expected)
			}

			// FIX: assign to the outer err so the require below actually checks
			// the Flush result. Previously `if err := r.conn.Flush(); ...`
			// shadowed err, and the final require.NoError examined the stale
			// (nil) err from GetSetElements, silently passing on Flush failure.
			r.conn.DelSet(set)
			err = r.conn.Flush()
			if err != nil {
				t.Logf("Failed to delete set: %v", err)
				printNftSets()
			}
			require.NoError(t, err, "Failed to delete set")
		})
	}
}

func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
	t.Helper()

	assert.NotNil(t, rule, "Rule should not be nil")

	// Verify sources and destination
	if expectSet {
		assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
	} else if len(sources) == 1 && sources[0].Bits() != 0 {
		if direction == firewall.RuleDirectionIN {
			assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
		} else {
			assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
		}
	}

	if direction == firewall.RuleDirectionIN {
		assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
	} else {
		assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true),
"Rule should contain source CIDR matcher for %s", destination) + } + + // Verify protocol + if proto != firewall.ProtocolALL { + assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto) + } + + // Verify ports + if sPort != nil { + assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort) + } + if dPort != nil { + assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort) + } + + // Verify action + assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action) +} + +func containsSetLookup(exprs []expr.Any) bool { + for _, e := range exprs { + if _, ok := e.(*expr.Lookup); ok { + return true + } + } + return false +} + +func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool { + var offset uint32 + if isSource { + offset = 12 // src offset + } else { + offset = 16 // dst offset + } + + var payloadFound, bitwiseFound, cmpFound bool + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Payload: + if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 { + payloadFound = true + } + case *expr.Bitwise: + if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 { + bitwiseFound = true + } + case *expr.Cmp: + if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 { + cmpFound = true + } + } + } + return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0 +} + +func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { + var offset uint32 = 2 // Default offset for destination port + if isSource { + offset = 0 // Offset for source port + } + + var payloadFound, portMatchFound bool + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Payload: + if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 { + payloadFound = true + } + case *expr.Cmp: + if 
port.IsRange { + if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte { + portMatchFound = true + } + } else { + if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 { + portValue := binary.BigEndian.Uint16(ex.Data) + for _, p := range port.Values { + if uint16(p) == portValue { + portMatchFound = true + break + } + } + } + } + } + if payloadFound && portMatchFound { + return true + } + } + return false +} + +func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { + var metaFound, cmpFound bool + expectedProto, _ := protoToInt(proto) + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Meta: + if ex.Key == expr.MetaKeyL4PROTO { + metaFound = true + } + case *expr.Cmp: + if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto { + cmpFound = true + } + } + } + return metaFound && cmpFound +} + +func containsAction(exprs []expr.Any, action firewall.Action) bool { + for _, e := range exprs { + if verdict, ok := e.(*expr.Verdict); ok { + switch action { + case firewall.ActionAccept: + return verdict.Kind == expr.VerdictAccept + case firewall.ActionDrop: + return verdict.Kind == expr.VerdictDrop + } + } + } + return false +} + // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. 
func check() int { nf := nftables.Conn{} @@ -250,12 +685,12 @@ func createWorkTable() (*nftables.Table, error) { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { sConn.DelTable(t) } } - table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) err = sConn.Flush() return table, err @@ -273,7 +708,7 @@ func deleteWorkTable() { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { sConn.DelTable(t) } } diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go index 432d113dd..267e93efd 100644 --- a/client/firewall/test/cases_linux.go +++ b/client/firewall/test/cases_linux.go @@ -1,8 +1,10 @@ -//go:build !android - package test -import firewall "github.com/netbirdio/netbird/client/firewall/manager" +import ( + "net/netip" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) var ( InsertRuleTestCases = []struct { @@ -13,8 +15,8 @@ var ( Name: "Insert Forwarding IPV4 Rule", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: false, }, }, @@ -22,8 +24,8 @@ var ( Name: "Insert Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: true, }, }, @@ -38,8 +40,8 @@ var ( Name: "Remove Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: 
true, }, }, diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 75792e9c0..681058ea9 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "net/netip" "sync" "github.com/google/gopacket" @@ -103,26 +104,26 @@ func (m *Manager) IsServerRouteSupported() bool { } } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } - return m.nativeFirewall.InsertRoutingRules(pair) + return m.nativeFirewall.AddNatRule(pair) } -// RemoveRoutingRules removes a routing firewall rule -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +// RemoveNatRule removes a routing firewall rule +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } - return m.nativeFirewall.RemoveRoutingRules(pair) + return m.nativeFirewall.RemoveNatRule(pair) } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -188,8 +189,22 @@ func (m *Manager) AddFiltering( return []firewall.Rule{&r}, nil } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { + if m.nativeFirewall == nil { + return nil, errRouteNotSupported + } + return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +func (m 
*Manager) DeleteRouteRule(rule firewall.Rule) error { + if m.nativeFirewall == nil { + return errRouteNotSupported + } + return m.nativeFirewall.DeleteRouteRule(rule) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -215,6 +230,11 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error { return nil } +// SetLegacyManagement doesn't need to be implemented for this manager +func (m *Manager) SetLegacyManagement(_ bool) error { + return nil +} + // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } @@ -395,7 +415,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { for _, r := range arr { if r.id == hookID { rule := r - return m.DeleteRule(&rule) + return m.DeletePeerRule(&rule) } } } @@ -403,7 +423,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { for _, r := range arr { if r.id == hookID { rule := r - return m.DeleteRule(&rule) + return m.DeletePeerRule(&rule) } } } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 514a90539..dd7366fe9 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -49,7 +49,7 @@ func TestManagerCreate(t *testing.T) { } } -func TestManagerAddFiltering(t *testing.T) { +func TestManagerAddPeerFiltering(t *testing.T) { isSetFilterCalled := false ifaceMock := &IFaceMock{ SetFilterFunc: func(iface.PacketFilter) error { @@ -71,7 +71,7 @@ func TestManagerAddFiltering(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -106,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) { action := fw.ActionDrop 
comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -119,14 +119,14 @@ func TestManagerDeleteRule(t *testing.T) { action = fw.ActionDrop comment = "Test rule 2" - rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return } for _, r := range rule { - err = m.DeleteRule(r) + err = m.DeletePeerRule(r) if err != nil { t.Errorf("failed to delete rule: %v", err) return @@ -140,7 +140,7 @@ func TestManagerDeleteRule(t *testing.T) { } for _, r := range rule2 { - err = m.DeleteRule(r) + err = m.DeletePeerRule(r) if err != nil { t.Errorf("failed to delete rule: %v", err) return @@ -252,7 +252,7 @@ func TestManagerReset(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -290,7 +290,7 @@ func TestNotMatchByIP(t *testing.T) { action := fw.ActionAccept comment := "Test rule" - _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -406,9 +406,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, 
fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go new file mode 100644 index 000000000..e27fce439 --- /dev/null +++ b/client/internal/acl/id/id.go @@ -0,0 +1,25 @@ +package id + +import ( + "fmt" + "net/netip" + + "github.com/netbirdio/netbird/client/firewall/manager" +) + +type RuleID string + +func (r RuleID) GetRuleID() string { + return string(r) +} + +func GenerateRouteRuleKey( + sources []netip.Prefix, + destination netip.Prefix, + proto manager.Protocol, + sPort *manager.Port, + dPort *manager.Port, + action manager.Action, +) RuleID { + return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action)) +} diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index fd2c2c875..ce2a12af1 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "net" + "net/netip" "strconv" "sync" "time" @@ -12,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -23,16 +25,18 @@ type Manager interface { // DefaultManager uses firewall manager to handle type DefaultManager struct { - firewall firewall.Manager - ipsetCounter int - rulesPairs map[string][]firewall.Rule - mutex sync.Mutex + firewall firewall.Manager + ipsetCounter int + peerRulesPairs map[id.RuleID][]firewall.Rule + routeRules map[id.RuleID]struct{} + mutex sync.Mutex } func 
NewDefaultManager(fm firewall.Manager) *DefaultManager { return &DefaultManager{ - firewall: fm, - rulesPairs: make(map[string][]firewall.Rule), + firewall: fm, + peerRulesPairs: make(map[id.RuleID][]firewall.Rule), + routeRules: make(map[id.RuleID]struct{}), } } @@ -46,7 +50,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { start := time.Now() defer func() { total := 0 - for _, pairs := range d.rulesPairs { + for _, pairs := range d.peerRulesPairs { total += len(pairs) } log.Infof( @@ -59,21 +63,34 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { return } - defer func() { - if err := d.firewall.Flush(); err != nil { - log.Error("failed to flush firewall rules: ", err) - } - }() + d.applyPeerACLs(networkMap) + // If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag, + // then the mgmt server is older than the client, and we need to allow all traffic for routes + isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty + if err := d.firewall.SetLegacyManagement(isLegacy); err != nil { + log.Errorf("failed to set legacy management flag: %v", err) + } + + if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil { + log.Errorf("Failed to apply route ACLs: %v", err) + } + + if err := d.firewall.Flush(); err != nil { + log.Error("failed to flush firewall rules: ", err) + } +} + +func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { rules, squashedProtocols := d.squashAcceptRules(networkMap) - enableSSH := (networkMap.PeerConfig != nil && + enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && - networkMap.PeerConfig.SshConfig.SshEnabled) - if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok { + networkMap.PeerConfig.SshConfig.SshEnabled + if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { enableSSH = enableSSH && !ok } - if _, ok := 
squashedProtocols[mgmProto.FirewallRule_TCP]; ok { + if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { enableSSH = enableSSH && !ok } @@ -83,9 +100,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, Port: strconv.Itoa(ssh.DefaultSSHPort), }) } @@ -97,20 +114,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, ) } - newRulePairs := make(map[string][]firewall.Rule) + newRulePairs := make(map[id.RuleID][]firewall.Rule) ipsetByRuleSelectors := make(map[string]string) for _, r := range rules { @@ -130,29 +147,97 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { break } if len(rules) > 0 { - d.rulesPairs[pairID] = rulePair + d.peerRulesPairs[pairID] = rulePair newRulePairs[pairID] = rulePair } } - for pairID, rules := range d.rulesPairs { + for pairID, rules := range d.peerRulesPairs { if _, ok := newRulePairs[pairID]; !ok { for _, rule := range rules { - if err := d.firewall.DeleteRule(rule); err != nil { - log.Errorf("failed to delete firewall rule: %v", err) + if err := 
d.firewall.DeletePeerRule(rule); err != nil { + log.Errorf("failed to delete peer firewall rule: %v", err) continue } } - delete(d.rulesPairs, pairID) + delete(d.peerRulesPairs, pairID) } } - d.rulesPairs = newRulePairs + d.peerRulesPairs = newRulePairs +} + +func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { + var newRouteRules = make(map[id.RuleID]struct{}) + for _, rule := range rules { + id, err := d.applyRouteACL(rule) + if err != nil { + return fmt.Errorf("apply route ACL: %w", err) + } + newRouteRules[id] = struct{}{} + } + + for id := range d.routeRules { + if _, ok := newRouteRules[id]; !ok { + if err := d.firewall.DeleteRouteRule(id); err != nil { + log.Errorf("failed to delete route firewall rule: %v", err) + continue + } + delete(d.routeRules, id) + } + } + d.routeRules = newRouteRules + return nil +} + +func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { + if len(rule.SourceRanges) == 0 { + return "", fmt.Errorf("source ranges is empty") + } + + var sources []netip.Prefix + for _, sourceRange := range rule.SourceRanges { + source, err := netip.ParsePrefix(sourceRange) + if err != nil { + return "", fmt.Errorf("parse source range: %w", err) + } + sources = append(sources, source) + } + + var destination netip.Prefix + if rule.IsDynamic { + destination = getDefault(sources[0]) + } else { + var err error + destination, err = netip.ParsePrefix(rule.Destination) + if err != nil { + return "", fmt.Errorf("parse destination: %w", err) + } + } + + protocol, err := convertToFirewallProtocol(rule.Protocol) + if err != nil { + return "", fmt.Errorf("invalid protocol: %w", err) + } + + action, err := convertFirewallAction(rule.Action) + if err != nil { + return "", fmt.Errorf("invalid action: %w", err) + } + + dPorts := convertPortInfo(rule.PortInfo) + + addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action) + if err != nil { + return "", 
fmt.Errorf("add route rule: %w", err) + } + + return id.RuleID(addedRule.GetRuleID()), nil } func (d *DefaultManager) protoRuleToFirewallRule( r *mgmProto.FirewallRule, ipsetName string, -) (string, []firewall.Rule, error) { +) (id.RuleID, []firewall.Rule, error) { ip := net.ParseIP(r.PeerIP) if ip == nil { return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") @@ -179,16 +264,16 @@ func (d *DefaultManager) protoRuleToFirewallRule( } } - ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "") - if rulesPair, ok := d.rulesPairs[ruleID]; ok { + ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "") + if rulesPair, ok := d.peerRulesPairs[ruleID]; ok { return ruleID, rulesPair, nil } var rules []firewall.Rule switch r.Direction { - case mgmProto.FirewallRule_IN: + case mgmProto.RuleDirection_IN: rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") - case mgmProto.FirewallRule_OUT: + case mgmProto.RuleDirection_OUT: rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") default: return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") @@ -210,7 +295,7 @@ func (d *DefaultManager) addInRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.firewall.AddFiltering( + rule, err := d.firewall.AddPeerFiltering( ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -221,7 +306,7 @@ func (d *DefaultManager) addInRules( return rules, nil } - rule, err = d.firewall.AddFiltering( + rule, err = d.firewall.AddPeerFiltering( ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -239,7 +324,7 @@ func (d *DefaultManager) addOutRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := 
d.firewall.AddFiltering( + rule, err := d.firewall.AddPeerFiltering( ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -250,7 +335,7 @@ func (d *DefaultManager) addOutRules( return rules, nil } - rule, err = d.firewall.AddFiltering( + rule, err = d.firewall.AddPeerFiltering( ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -259,21 +344,21 @@ func (d *DefaultManager) addOutRules( return append(rules, rule...), nil } -// getRuleID() returns unique ID for the rule based on its parameters. -func (d *DefaultManager) getRuleID( +// getPeerRuleID() returns unique ID for the rule based on its parameters. +func (d *DefaultManager) getPeerRuleID( ip net.IP, proto firewall.Protocol, direction int, port *firewall.Port, action firewall.Action, comment string, -) string { +) id.RuleID { idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment if port != nil { idStr += port.String() } - return hex.EncodeToString(md5.New().Sum([]byte(idStr))) + return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } // squashAcceptRules does complex logic to convert many rules which allows connection by traffic type @@ -283,7 +368,7 @@ func (d *DefaultManager) getRuleID( // but other has port definitions or has drop policy. func (d *DefaultManager) squashAcceptRules( networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) { +) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { totalIPs := 0 for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) 
{ for range p.AllowedIps { @@ -291,14 +376,14 @@ func (d *DefaultManager) squashAcceptRules( } } - type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int + type protoMatch map[mgmProto.RuleProtocol]map[string]int in := protoMatch{} out := protoMatch{} // trace which type of protocols was squashed squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{} + squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} // this function we use to do calculation, can we squash the rules by protocol or not. // We summ amount of Peers IP for given protocol we found in original rules list. @@ -308,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules( // // We zeroed this to notify squash function that this protocol can't be squashed. addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) { - drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != "" + drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" if drop { protocols[r.Protocol] = map[string]int{} return @@ -336,7 +421,7 @@ func (d *DefaultManager) squashAcceptRules( for i, r := range networkMap.FirewallRules { // calculate squash for different directions - if r.Direction == mgmProto.FirewallRule_IN { + if r.Direction == mgmProto.RuleDirection_IN { addRuleToCalculationMap(i, r, in) } else { addRuleToCalculationMap(i, r, out) @@ -345,14 +430,14 @@ func (d *DefaultManager) squashAcceptRules( // order of squashing by protocol is important // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.FirewallRuleProtocol{ - mgmProto.FirewallRule_ALL, - mgmProto.FirewallRule_ICMP, - mgmProto.FirewallRule_TCP, - mgmProto.FirewallRule_UDP, + protocolOrders := []mgmProto.RuleProtocol{ + mgmProto.RuleProtocol_ALL, + mgmProto.RuleProtocol_ICMP, + mgmProto.RuleProtocol_TCP, + mgmProto.RuleProtocol_UDP, } - squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) { + squash := 
func(matches protoMatch, direction mgmProto.RuleDirection) { for _, protocol := range protocolOrders { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { // don't squash if : @@ -365,12 +450,12 @@ func (d *DefaultManager) squashAcceptRules( squashedRules = append(squashedRules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", Direction: direction, - Action: mgmProto.FirewallRule_ACCEPT, + Action: mgmProto.RuleAction_ACCEPT, Protocol: protocol, }) squashedProtocols[protocol] = struct{}{} - if protocol == mgmProto.FirewallRule_ALL { + if protocol == mgmProto.RuleProtocol_ALL { // if we have ALL traffic type squashed rule // it allows all other type of traffic, so we can stop processing break @@ -378,11 +463,11 @@ func (d *DefaultManager) squashAcceptRules( } } - squash(in, mgmProto.FirewallRule_IN) - squash(out, mgmProto.FirewallRule_OUT) + squash(in, mgmProto.RuleDirection_IN) + squash(out, mgmProto.RuleDirection_OUT) // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok { + if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { return squashedRules, squashedProtocols } @@ -412,26 +497,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) } -func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) { +func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { log.Debugf("rollback ACL to previous state") for _, rules := range newRulePairs { for _, rule := range rules { - if err := d.firewall.DeleteRule(rule); err != nil { + if err := d.firewall.DeletePeerRule(rule); err != nil { log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) } } } } -func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) 
(firewall.Protocol, error) { +func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { switch protocol { - case mgmProto.FirewallRule_TCP: + case mgmProto.RuleProtocol_TCP: return firewall.ProtocolTCP, nil - case mgmProto.FirewallRule_UDP: + case mgmProto.RuleProtocol_UDP: return firewall.ProtocolUDP, nil - case mgmProto.FirewallRule_ICMP: + case mgmProto.RuleProtocol_ICMP: return firewall.ProtocolICMP, nil - case mgmProto.FirewallRule_ALL: + case mgmProto.RuleProtocol_ALL: return firewall.ProtocolALL, nil default: return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) @@ -442,13 +527,41 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil } -func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) { +func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) { switch action { - case mgmProto.FirewallRule_ACCEPT: + case mgmProto.RuleAction_ACCEPT: return firewall.ActionAccept, nil - case mgmProto.FirewallRule_DROP: + case mgmProto.RuleAction_DROP: return firewall.ActionDrop, nil default: return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action) } } + +func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port { + if portInfo == nil { + return nil + } + + if portInfo.GetPort() != 0 { + return &firewall.Port{ + Values: []int{int(portInfo.GetPort())}, + } + } + + if portInfo.GetRange() != nil { + return &firewall.Port{ + IsRange: true, + Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)}, + } + } + + return nil +} + +func getDefault(prefix netip.Prefix) netip.Prefix { + if prefix.Addr().Is6() { + return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + } + return netip.PrefixFrom(netip.IPv4Unspecified(), 0) +} diff --git a/client/internal/acl/manager_test.go 
b/client/internal/acl/manager_test.go index 494d54bf2..eec3d3b8c 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, Port: "80", }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_DROP, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_UDP, Port: "53", }, }, @@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) { acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 2 { - t.Errorf("firewall rules not applied: %v", acl.rulesPairs) + if len(acl.peerRulesPairs) != 2 { + t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) return } }) t.Run("add extra rules", func(t *testing.T) { existedPairs := map[string]struct{}{} - for id := range acl.rulesPairs { - existedPairs[id] = struct{}{} + for id := range acl.peerRulesPairs { + existedPairs[id.GetRuleID()] = struct{}{} } // remove first rule @@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules, &mgmProto.FirewallRule{ PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_DROP, - Protocol: mgmProto.FirewallRule_ICMP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_ICMP, }, ) acl.ApplyFiltering(networkMap) // we should have one old and one new rule in the existed rules - if len(acl.rulesPairs) != 2 { + if len(acl.peerRulesPairs) != 2 { t.Errorf("firewall rules not applied") return } // check 
that old rule was removed previousCount := 0 - for id := range acl.rulesPairs { - if _, ok := existedPairs[id]; ok { + for id := range acl.peerRulesPairs { + if _, ok := existedPairs[id.GetRuleID()]; ok { previousCount++ } } @@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRulesIsEmpty = true - if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 { - t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs)) + if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { + t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) return } networkMap.FirewallRulesIsEmpty = false acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 2 { - t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs)) + if len(acl.peerRulesPairs) != 2 { + t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) return } }) @@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: 
"10.93.0.4", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, }, } @@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) { case r.PeerIP != "0.0.0.0": t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) return - case r.Direction != mgmProto.FirewallRule_IN: + case r.Direction != mgmProto.RuleDirection_IN: t.Errorf("direction should be IN, got: %v", r.Direction) return - case r.Protocol != mgmProto.FirewallRule_ALL: + case r.Protocol != mgmProto.RuleProtocol_ALL: t.Errorf("protocol should be ALL, got: %v", r.Protocol) return - case r.Action != mgmProto.FirewallRule_ACCEPT: + case r.Action != mgmProto.RuleAction_ACCEPT: t.Errorf("action should be ACCEPT, got: %v", r.Action) return } @@ -215,13 +215,13 @@ func 
TestDefaultManagerSquashRules(t *testing.T) { case r.PeerIP != "0.0.0.0": t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) return - case r.Direction != mgmProto.FirewallRule_OUT: + case r.Direction != mgmProto.RuleDirection_OUT: t.Errorf("direction should be OUT, got: %v", r.Direction) return - case r.Protocol != mgmProto.FirewallRule_ALL: + case r.Protocol != mgmProto.RuleProtocol_ALL: t.Errorf("protocol should be ALL, got: %v", r.Protocol) return - case r.Action != mgmProto.FirewallRule_ACCEPT: + case r.Action != mgmProto.RuleAction_ACCEPT: t.Errorf("action should be ACCEPT, got: %v", r.Action) return } @@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: 
mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, }, }, } @@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, }, }, } @@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 4 { - t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs)) + if 
len(acl.peerRulesPairs) != 4 { + t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) return } } diff --git a/client/internal/engine.go b/client/internal/engine.go index 463507ad8..998cbce2d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -704,6 +704,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } + // Apply ACLs in the beginning to avoid security leaks + if e.acl != nil { + e.acl.ApplyFiltering(networkMap) + } + protoRoutes := networkMap.GetRoutes() if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} @@ -770,10 +775,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { log.Errorf("failed to update dns server, err: %v", err) } - if e.acl != nil { - e.acl.ApplyFiltering(networkMap) - } - e.networkSerial = serial // Test received (upstream) servers for availability right away instead of upon usage. diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 5897031e7..e86a52810 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -303,7 +303,7 @@ func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]neti var merr *multierror.Error for _, prefix := range prefixes { - if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil { + if _, err := r.routeRefCounter.Increment(prefix, struct{}{}); err != nil { merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err)) continue } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index cdfd322bd..d97fe631f 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -87,10 +87,10 @@ func NewManager( } dm.routeRefCounter = refcounter.New( - func(prefix netip.Prefix, _ any) (any, error) { - return nil, sysOps.AddVPNRoute(prefix, 
wgInterface.ToInterface()) + func(prefix netip.Prefix, _ struct{}) (struct{}, error) { + return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) }, - func(prefix netip.Prefix, _ any) error { + func(prefix netip.Prefix, _ struct{}) error { return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface()) }, ) diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index f1d696ad9..65ea0f708 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -3,7 +3,8 @@ package refcounter import ( "errors" "fmt" - "net/netip" + "runtime" + "strings" "sync" "github.com/hashicorp/go-multierror" @@ -12,118 +13,153 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" ) -// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix. +const logLevel = log.TraceLevel + +// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key. var ErrIgnore = errors.New("ignore") +// Ref holds the reference count and associated data for a key. type Ref[O any] struct { Count int Out O } -type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error) -type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error +// AddFunc is the function type for adding a new key. +// Key is the type of the key (e.g., netip.Prefix). +type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error) -type Counter[I, O any] struct { - // refCountMap keeps track of the reference Ref for prefixes - refCountMap map[netip.Prefix]Ref[O] +// RemoveFunc is the function type for removing a key. +type RemoveFunc[Key, O any] func(key Key, out O) error + +// Counter is a generic reference counter for managing keys and their associated data. +// Key: The type of the key (e.g., netip.Prefix, string). +// +// I: The input type for the AddFunc. 
It is the input type for additional data needed +// when adding a key, it is passed as the second argument to AddFunc. +// +// O: The output type for the AddFunc and RemoveFunc. This is the output returned by AddFunc. +// It is stored and passed to RemoveFunc when the reference count reaches 0. +// +// The types can be aliased to a specific type using the following syntax: +// +// type RouteRefCounter = Counter[netip.Prefix, any, any] +type Counter[Key comparable, I, O any] struct { + // refCountMap keeps track of the reference Ref for keys + refCountMap map[Key]Ref[O] refCountMu sync.Mutex - // idMap keeps track of the prefixes associated with an ID for removal - idMap map[string][]netip.Prefix + // idMap keeps track of the keys associated with an ID for removal + idMap map[string][]Key idMu sync.Mutex - add AddFunc[I, O] - remove RemoveFunc[I, O] + add AddFunc[Key, I, O] + remove RemoveFunc[Key, O] } -// New creates a new Counter instance -func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] { - return &Counter[I, O]{ - refCountMap: map[netip.Prefix]Ref[O]{}, - idMap: map[string][]netip.Prefix{}, +// New creates a new Counter instance. +// Usage example: +// +// counter := New[netip.Prefix, string, string]( +// func(key netip.Prefix, in string) (out string, err error) { ... }, +// func(key netip.Prefix, out string) error { ... }, +// ) +func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) *Counter[Key, I, O] { + return &Counter[Key, I, O]{ + refCountMap: map[Key]Ref[O]{}, + idMap: map[string][]Key{}, add: add, remove: remove, } } -// Increment increments the reference count for the given prefix. -// If this is the first reference to the prefix, the AddFunc is called. -func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) { +// Get retrieves the current reference count and associated data for a key. +// If the key doesn't exist, it returns a zero value Ref and false. 
+func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() - ref := rm.refCountMap[prefix] - log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + ref, ok := rm.refCountMap[key] + return ref, ok +} - // Call AddFunc only if it's a new prefix +// Increment increments the reference count for the given key. +// If this is the first reference to the key, the AddFunc is called. +func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + + ref := rm.refCountMap[key] + logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) + + // Call AddFunc only if it's a new key if ref.Count == 0 { - log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out) - out, err := rm.add(prefix, in) + logCallerF("Calling add for key %v", key) + out, err := rm.add(key, in) if errors.Is(err, ErrIgnore) { return ref, nil } if err != nil { - return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err) + return ref, fmt.Errorf("failed to add for key %v: %w", key, err) } ref.Out = out } ref.Count++ - rm.refCountMap[prefix] = ref + rm.refCountMap[key] = ref return ref, nil } -// IncrementWithID increments the reference count for the given prefix and groups it under the given ID. -// If this is the first reference to the prefix, the AddFunc is called. -func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) { +// IncrementWithID increments the reference count for the given key and groups it under the given ID. +// If this is the first reference to the key, the AddFunc is called. 
+func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { rm.idMu.Lock() defer rm.idMu.Unlock() - ref, err := rm.Increment(prefix, in) + ref, err := rm.Increment(key, in) if err != nil { return ref, fmt.Errorf("with ID: %w", err) } - rm.idMap[id] = append(rm.idMap[id], prefix) + rm.idMap[id] = append(rm.idMap[id], key) return ref, nil } -// Decrement decrements the reference count for the given prefix. +// Decrement decrements the reference count for the given key. // If the reference count reaches 0, the RemoveFunc is called. -func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) { +func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() - ref, ok := rm.refCountMap[prefix] + ref, ok := rm.refCountMap[key] if !ok { - log.Tracef("No reference found for prefix %s", prefix) + logCallerF("No reference found for key %v", key) return ref, nil } - log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + logCallerF("Decreasing ref count [%d -> %d] for key %v with Out [%v]", ref.Count, ref.Count-1, key, ref.Out) if ref.Count == 1 { - log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out) - if err := rm.remove(prefix, ref.Out); err != nil { - return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err) + logCallerF("Calling remove for key %v", key) + if err := rm.remove(key, ref.Out); err != nil { + return ref, fmt.Errorf("remove for key %v: %w", key, err) } - delete(rm.refCountMap, prefix) + delete(rm.refCountMap, key) } else { ref.Count-- - rm.refCountMap[prefix] = ref + rm.refCountMap[key] = ref } return ref, nil } -// DecrementWithID decrements the reference count for all prefixes associated with the given ID. +// DecrementWithID decrements the reference count for all keys associated with the given ID. // If the reference count reaches 0, the RemoveFunc is called. 
-func (rm *Counter[I, O]) DecrementWithID(id string) error { +func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { rm.idMu.Lock() defer rm.idMu.Unlock() var merr *multierror.Error - for _, prefix := range rm.idMap[id] { - if _, err := rm.Decrement(prefix); err != nil { + for _, key := range rm.idMap[id] { + if _, err := rm.Decrement(key); err != nil { merr = multierror.Append(merr, err) } } @@ -132,24 +168,77 @@ func (rm *Counter[I, O]) DecrementWithID(id string) error { return nberrors.FormatErrorOrNil(merr) } -// Flush removes all references and calls RemoveFunc for each prefix. -func (rm *Counter[I, O]) Flush() error { +// Flush removes all references and calls RemoveFunc for each key. +func (rm *Counter[Key, I, O]) Flush() error { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() rm.idMu.Lock() defer rm.idMu.Unlock() var merr *multierror.Error - for prefix := range rm.refCountMap { - log.Tracef("Removing for prefix %s", prefix) - ref := rm.refCountMap[prefix] - if err := rm.remove(prefix, ref.Out); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err)) + for key := range rm.refCountMap { + logCallerF("Calling remove for key %v", key) + ref := rm.refCountMap[key] + if err := rm.remove(key, ref.Out); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove for key %v: %w", key, err)) } } - rm.refCountMap = map[netip.Prefix]Ref[O]{} - rm.idMap = map[string][]netip.Prefix{} + clear(rm.refCountMap) + clear(rm.idMap) return nberrors.FormatErrorOrNil(merr) } + +// Clear removes all references without calling RemoveFunc. 
+func (rm *Counter[Key, I, O]) Clear() { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + + clear(rm.refCountMap) + clear(rm.idMap) +} + +func getCallerInfo(depth int, maxDepth int) (string, bool) { + if depth >= maxDepth { + return "", false + } + + pc, _, _, ok := runtime.Caller(depth) + if !ok { + return "", false + } + + if details := runtime.FuncForPC(pc); details != nil { + name := details.Name() + + lastDotIndex := strings.LastIndex(name, "/") + if lastDotIndex != -1 { + name = name[lastDotIndex+1:] + } + + if strings.HasPrefix(name, "refcounter.") { + // +2 to account for recursion + return getCallerInfo(depth+2, maxDepth) + } + + return name, true + } + + return "", false +} + +// logCallerF logs a message with the package name and method of the function that called the current function. +func logCallerF(format string, args ...interface{}) { + if log.GetLevel() < logLevel { + return + } + + if callerName, ok := getCallerInfo(3, 18); ok { + format = fmt.Sprintf("[%s] %s", callerName, format) + } + + log.StandardLogger().Logf(logLevel, format, args...) 
+} diff --git a/client/internal/routemanager/refcounter/types.go b/client/internal/routemanager/refcounter/types.go index 6753b64ef..aadac3e25 100644 --- a/client/internal/routemanager/refcounter/types.go +++ b/client/internal/routemanager/refcounter/types.go @@ -1,7 +1,9 @@ package refcounter +import "net/netip" + // RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement -type RouteRefCounter = Counter[any, any] +type RouteRefCounter = Counter[netip.Prefix, struct{}, struct{}] // AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement -type AllowedIPsRefCounter = Counter[string, string] +type AllowedIPsRefCounter = Counter[netip.Prefix, string, string] diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 43a266cd2..1d1a4b063 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -94,7 +94,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error return fmt.Errorf("parse prefix: %w", err) } - err = m.firewall.RemoveRoutingRules(routerPair) + err = m.firewall.RemoveNatRule(routerPair) if err != nil { return fmt.Errorf("remove routing rules: %w", err) } @@ -123,7 +123,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { return fmt.Errorf("parse prefix: %w", err) } - err = m.firewall.InsertRoutingRules(routerPair) + err = m.firewall.AddNatRule(routerPair) if err != nil { return fmt.Errorf("insert routing rules: %w", err) } @@ -157,7 +157,7 @@ func (m *defaultServerRouter) cleanUp() { continue } - err = m.firewall.RemoveRoutingRules(routerPair) + err = m.firewall.RemoveNatRule(routerPair) if err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } @@ -173,15 +173,15 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { // 
TODO: add ipv6 source := getDefaultPrefix(route.Network) - destination := route.Network.Masked().String() + destination := route.Network.Masked() if route.IsDynamic() { - // TODO: add ipv6 - destination = "0.0.0.0/0" + // TODO: add ipv6 additionally + destination = getDefaultPrefix(destination) } return firewall.RouterPair{ - ID: string(route.ID), - Source: source.String(), + ID: route.ID, + Source: source, Destination: destination, Masquerade: route.Masquerade, }, nil diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go index 88cca522a..98c34dbee 100644 --- a/client/internal/routemanager/static/route.go +++ b/client/internal/routemanager/static/route.go @@ -30,7 +30,7 @@ func (r *Route) String() string { } func (r *Route) AddRoute(context.Context) error { - _, err := r.routeRefCounter.Increment(r.route.Network, nil) + _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}) return err } diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index ae27b0123..10944c1e2 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -15,7 +15,7 @@ type Nexthop struct { Intf *net.Interface } -type ExclusionCounter = refcounter.Counter[any, Nexthop] +type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type SysOps struct { refCounter *ExclusionCounter diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index d76824c10..90f06ba78 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -41,7 +41,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn } refCounter := refcounter.New( - func(prefix netip.Prefix, _ any) (Nexthop, error) { + func(prefix netip.Prefix, _ 
struct{}) (Nexthop, error) { initialNexthop := initialNextHopV4 if prefix.Addr().Is6() { initialNexthop = initialNextHopV6 @@ -317,7 +317,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("convert ip to prefix: %w", err) } - if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil { + if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 48f048c4c..672b2a102 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.21.12 +// protoc v4.23.4 // source: management.proto package proto @@ -21,6 +21,153 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type RuleProtocol int32 + +const ( + RuleProtocol_UNKNOWN RuleProtocol = 0 + RuleProtocol_ALL RuleProtocol = 1 + RuleProtocol_TCP RuleProtocol = 2 + RuleProtocol_UDP RuleProtocol = 3 + RuleProtocol_ICMP RuleProtocol = 4 +) + +// Enum value maps for RuleProtocol. 
+var ( + RuleProtocol_name = map[int32]string{ + 0: "UNKNOWN", + 1: "ALL", + 2: "TCP", + 3: "UDP", + 4: "ICMP", + } + RuleProtocol_value = map[string]int32{ + "UNKNOWN": 0, + "ALL": 1, + "TCP": 2, + "UDP": 3, + "ICMP": 4, + } +) + +func (x RuleProtocol) Enum() *RuleProtocol { + p := new(RuleProtocol) + *p = x + return p +} + +func (x RuleProtocol) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RuleProtocol) Descriptor() protoreflect.EnumDescriptor { + return file_management_proto_enumTypes[0].Descriptor() +} + +func (RuleProtocol) Type() protoreflect.EnumType { + return &file_management_proto_enumTypes[0] +} + +func (x RuleProtocol) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RuleProtocol.Descriptor instead. +func (RuleProtocol) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{0} +} + +type RuleDirection int32 + +const ( + RuleDirection_IN RuleDirection = 0 + RuleDirection_OUT RuleDirection = 1 +) + +// Enum value maps for RuleDirection. +var ( + RuleDirection_name = map[int32]string{ + 0: "IN", + 1: "OUT", + } + RuleDirection_value = map[string]int32{ + "IN": 0, + "OUT": 1, + } +) + +func (x RuleDirection) Enum() *RuleDirection { + p := new(RuleDirection) + *p = x + return p +} + +func (x RuleDirection) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RuleDirection) Descriptor() protoreflect.EnumDescriptor { + return file_management_proto_enumTypes[1].Descriptor() +} + +func (RuleDirection) Type() protoreflect.EnumType { + return &file_management_proto_enumTypes[1] +} + +func (x RuleDirection) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RuleDirection.Descriptor instead. 
+func (RuleDirection) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{1} +} + +type RuleAction int32 + +const ( + RuleAction_ACCEPT RuleAction = 0 + RuleAction_DROP RuleAction = 1 +) + +// Enum value maps for RuleAction. +var ( + RuleAction_name = map[int32]string{ + 0: "ACCEPT", + 1: "DROP", + } + RuleAction_value = map[string]int32{ + "ACCEPT": 0, + "DROP": 1, + } +) + +func (x RuleAction) Enum() *RuleAction { + p := new(RuleAction) + *p = x + return p +} + +func (x RuleAction) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RuleAction) Descriptor() protoreflect.EnumDescriptor { + return file_management_proto_enumTypes[2].Descriptor() +} + +func (RuleAction) Type() protoreflect.EnumType { + return &file_management_proto_enumTypes[2] +} + +func (x RuleAction) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RuleAction.Descriptor instead. +func (RuleAction) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{2} +} + type HostConfig_Protocol int32 const ( @@ -60,11 +207,11 @@ func (x HostConfig_Protocol) String() string { } func (HostConfig_Protocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[0].Descriptor() + return file_management_proto_enumTypes[3].Descriptor() } func (HostConfig_Protocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[0] + return &file_management_proto_enumTypes[3] } func (x HostConfig_Protocol) Number() protoreflect.EnumNumber { @@ -103,11 +250,11 @@ func (x DeviceAuthorizationFlowProvider) String() string { } func (DeviceAuthorizationFlowProvider) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[1].Descriptor() + return file_management_proto_enumTypes[4].Descriptor() } func (DeviceAuthorizationFlowProvider) Type() protoreflect.EnumType { - return 
&file_management_proto_enumTypes[1] + return &file_management_proto_enumTypes[4] } func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { @@ -119,153 +266,6 @@ func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { return file_management_proto_rawDescGZIP(), []int{21, 0} } -type FirewallRuleDirection int32 - -const ( - FirewallRule_IN FirewallRuleDirection = 0 - FirewallRule_OUT FirewallRuleDirection = 1 -) - -// Enum value maps for FirewallRuleDirection. -var ( - FirewallRuleDirection_name = map[int32]string{ - 0: "IN", - 1: "OUT", - } - FirewallRuleDirection_value = map[string]int32{ - "IN": 0, - "OUT": 1, - } -) - -func (x FirewallRuleDirection) Enum() *FirewallRuleDirection { - p := new(FirewallRuleDirection) - *p = x - return p -} - -func (x FirewallRuleDirection) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (FirewallRuleDirection) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[2].Descriptor() -} - -func (FirewallRuleDirection) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[2] -} - -func (x FirewallRuleDirection) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use FirewallRuleDirection.Descriptor instead. -func (FirewallRuleDirection) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31, 0} -} - -type FirewallRuleAction int32 - -const ( - FirewallRule_ACCEPT FirewallRuleAction = 0 - FirewallRule_DROP FirewallRuleAction = 1 -) - -// Enum value maps for FirewallRuleAction. 
-var ( - FirewallRuleAction_name = map[int32]string{ - 0: "ACCEPT", - 1: "DROP", - } - FirewallRuleAction_value = map[string]int32{ - "ACCEPT": 0, - "DROP": 1, - } -) - -func (x FirewallRuleAction) Enum() *FirewallRuleAction { - p := new(FirewallRuleAction) - *p = x - return p -} - -func (x FirewallRuleAction) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (FirewallRuleAction) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[3].Descriptor() -} - -func (FirewallRuleAction) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[3] -} - -func (x FirewallRuleAction) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use FirewallRuleAction.Descriptor instead. -func (FirewallRuleAction) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31, 1} -} - -type FirewallRuleProtocol int32 - -const ( - FirewallRule_UNKNOWN FirewallRuleProtocol = 0 - FirewallRule_ALL FirewallRuleProtocol = 1 - FirewallRule_TCP FirewallRuleProtocol = 2 - FirewallRule_UDP FirewallRuleProtocol = 3 - FirewallRule_ICMP FirewallRuleProtocol = 4 -) - -// Enum value maps for FirewallRuleProtocol. 
-var ( - FirewallRuleProtocol_name = map[int32]string{ - 0: "UNKNOWN", - 1: "ALL", - 2: "TCP", - 3: "UDP", - 4: "ICMP", - } - FirewallRuleProtocol_value = map[string]int32{ - "UNKNOWN": 0, - "ALL": 1, - "TCP": 2, - "UDP": 3, - "ICMP": 4, - } -) - -func (x FirewallRuleProtocol) Enum() *FirewallRuleProtocol { - p := new(FirewallRuleProtocol) - *p = x - return p -} - -func (x FirewallRuleProtocol) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (FirewallRuleProtocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[4].Descriptor() -} - -func (FirewallRuleProtocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[4] -} - -func (x FirewallRuleProtocol) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use FirewallRuleProtocol.Descriptor instead. -func (FirewallRuleProtocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31, 2} -} - type EncryptedMessage struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1482,6 +1482,10 @@ type NetworkMap struct { FirewallRules []*FirewallRule `protobuf:"bytes,8,rep,name=FirewallRules,proto3" json:"FirewallRules,omitempty"` // firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality. FirewallRulesIsEmpty bool `protobuf:"varint,9,opt,name=firewallRulesIsEmpty,proto3" json:"firewallRulesIsEmpty,omitempty"` + // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer + RoutesFirewallRules []*RouteFirewallRule `protobuf:"bytes,10,rep,name=routesFirewallRules,proto3" json:"routesFirewallRules,omitempty"` + // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. 
+ RoutesFirewallRulesIsEmpty bool `protobuf:"varint,11,opt,name=routesFirewallRulesIsEmpty,proto3" json:"routesFirewallRulesIsEmpty,omitempty"` } func (x *NetworkMap) Reset() { @@ -1579,6 +1583,20 @@ func (x *NetworkMap) GetFirewallRulesIsEmpty() bool { return false } +func (x *NetworkMap) GetRoutesFirewallRules() []*RouteFirewallRule { + if x != nil { + return x.RoutesFirewallRules + } + return nil +} + +func (x *NetworkMap) GetRoutesFirewallRulesIsEmpty() bool { + if x != nil { + return x.RoutesFirewallRulesIsEmpty + } + return false +} + // RemotePeerConfig represents a configuration of a remote peer. // The properties are used to configure WireGuard Peers sections type RemotePeerConfig struct { @@ -2487,11 +2505,11 @@ type FirewallRule struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"` - Direction FirewallRuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.FirewallRuleDirection" json:"Direction,omitempty"` - Action FirewallRuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.FirewallRuleAction" json:"Action,omitempty"` - Protocol FirewallRuleProtocol `protobuf:"varint,4,opt,name=Protocol,proto3,enum=management.FirewallRuleProtocol" json:"Protocol,omitempty"` - Port string `protobuf:"bytes,5,opt,name=Port,proto3" json:"Port,omitempty"` + PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"` + Direction RuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.RuleDirection" json:"Direction,omitempty"` + Action RuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.RuleAction" json:"Action,omitempty"` + Protocol RuleProtocol `protobuf:"varint,4,opt,name=Protocol,proto3,enum=management.RuleProtocol" json:"Protocol,omitempty"` + Port string `protobuf:"bytes,5,opt,name=Port,proto3" json:"Port,omitempty"` } func (x *FirewallRule) Reset() { @@ -2533,25 
+2551,25 @@ func (x *FirewallRule) GetPeerIP() string { return "" } -func (x *FirewallRule) GetDirection() FirewallRuleDirection { +func (x *FirewallRule) GetDirection() RuleDirection { if x != nil { return x.Direction } - return FirewallRule_IN + return RuleDirection_IN } -func (x *FirewallRule) GetAction() FirewallRuleAction { +func (x *FirewallRule) GetAction() RuleAction { if x != nil { return x.Action } - return FirewallRule_ACCEPT + return RuleAction_ACCEPT } -func (x *FirewallRule) GetProtocol() FirewallRuleProtocol { +func (x *FirewallRule) GetProtocol() RuleProtocol { if x != nil { return x.Protocol } - return FirewallRule_UNKNOWN + return RuleProtocol_UNKNOWN } func (x *FirewallRule) GetPort() string { @@ -2663,6 +2681,236 @@ func (x *Checks) GetFiles() []string { return nil } +type PortInfo struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to PortSelection: + // + // *PortInfo_Port + // *PortInfo_Range_ + PortSelection isPortInfo_PortSelection `protobuf_oneof:"portSelection"` +} + +func (x *PortInfo) Reset() { + *x = PortInfo{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[34] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PortInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PortInfo) ProtoMessage() {} + +func (x *PortInfo) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[34] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. 
+func (*PortInfo) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{34} +} + +func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { + if m != nil { + return m.PortSelection + } + return nil +} + +func (x *PortInfo) GetPort() uint32 { + if x, ok := x.GetPortSelection().(*PortInfo_Port); ok { + return x.Port + } + return 0 +} + +func (x *PortInfo) GetRange() *PortInfo_Range { + if x, ok := x.GetPortSelection().(*PortInfo_Range_); ok { + return x.Range + } + return nil +} + +type isPortInfo_PortSelection interface { + isPortInfo_PortSelection() +} + +type PortInfo_Port struct { + Port uint32 `protobuf:"varint,1,opt,name=port,proto3,oneof"` +} + +type PortInfo_Range_ struct { + Range *PortInfo_Range `protobuf:"bytes,2,opt,name=range,proto3,oneof"` +} + +func (*PortInfo_Port) isPortInfo_PortSelection() {} + +func (*PortInfo_Range_) isPortInfo_PortSelection() {} + +// RouteFirewallRule signifies a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // sourceRanges IP ranges of the routing peers. + SourceRanges []string `protobuf:"bytes,1,rep,name=sourceRanges,proto3" json:"sourceRanges,omitempty"` + // Action to be taken by the firewall when the rule is applicable. + Action RuleAction `protobuf:"varint,2,opt,name=action,proto3,enum=management.RuleAction" json:"action,omitempty"` + // Network prefix for the routed network. + Destination string `protobuf:"bytes,3,opt,name=destination,proto3" json:"destination,omitempty"` + // Protocol of the routed network. + Protocol RuleProtocol `protobuf:"varint,4,opt,name=protocol,proto3,enum=management.RuleProtocol" json:"protocol,omitempty"` + // Details about the port. + PortInfo *PortInfo `protobuf:"bytes,5,opt,name=portInfo,proto3" json:"portInfo,omitempty"` + // IsDynamic indicates if the route is a DNS route. 
+ IsDynamic bool `protobuf:"varint,6,opt,name=isDynamic,proto3" json:"isDynamic,omitempty"` +} + +func (x *RouteFirewallRule) Reset() { + *x = RouteFirewallRule{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[35] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RouteFirewallRule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RouteFirewallRule) ProtoMessage() {} + +func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[35] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. +func (*RouteFirewallRule) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{35} +} + +func (x *RouteFirewallRule) GetSourceRanges() []string { + if x != nil { + return x.SourceRanges + } + return nil +} + +func (x *RouteFirewallRule) GetAction() RuleAction { + if x != nil { + return x.Action + } + return RuleAction_ACCEPT +} + +func (x *RouteFirewallRule) GetDestination() string { + if x != nil { + return x.Destination + } + return "" +} + +func (x *RouteFirewallRule) GetProtocol() RuleProtocol { + if x != nil { + return x.Protocol + } + return RuleProtocol_UNKNOWN +} + +func (x *RouteFirewallRule) GetPortInfo() *PortInfo { + if x != nil { + return x.PortInfo + } + return nil +} + +func (x *RouteFirewallRule) GetIsDynamic() bool { + if x != nil { + return x.IsDynamic + } + return false +} + +type PortInfo_Range struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` + End uint32 `protobuf:"varint,2,opt,name=end,proto3" json:"end,omitempty"` +} + 
+func (x *PortInfo_Range) Reset() { + *x = PortInfo_Range{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[36] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PortInfo_Range) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PortInfo_Range) ProtoMessage() {} + +func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[36] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. +func (*PortInfo_Range) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{34, 0} +} + +func (x *PortInfo_Range) GetStart() uint32 { + if x != nil { + return x.Start + } + return 0 +} + +func (x *PortInfo_Range) GetEnd() uint32 { + if x != nil { + return x.End + } + return 0 +} + var File_management_proto protoreflect.FileDescriptor var file_management_proto_rawDesc = []byte{ @@ -2835,7 +3083,7 @@ var file_management_proto_rawDesc = []byte{ 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xe2, 0x03, 0x0a, 0x0a, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xf3, 0x04, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, @@ 
-2866,184 +3114,219 @@ var file_management_proto_rawDesc = []byte{ 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, - 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, - 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 
0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, - 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, - 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 
0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, - 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, - 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, - 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 
0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, - 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, - 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, - 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, - 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 
0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, - 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, - 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, - 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, - 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, - 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, - 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 
0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, - 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, - 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, - 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, - 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, - 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, - 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, - 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, - 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xf0, 0x02, 0x0a, 0x0c, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, - 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, - 0x72, 0x49, 0x50, 0x12, 0x40, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 
0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x37, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, - 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x3d, - 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, - 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, - 0x74, 0x22, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, - 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x22, - 0x1e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, - 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x22, - 0x3c, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, - 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, - 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, - 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x22, 0x38, 0x0a, - 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, - 
0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, - 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, - 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, - 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, - 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, + 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, + 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, + 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 
0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, + 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, + 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, + 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, + 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, + 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, + 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, + 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 
0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, + 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, + 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, + 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, + 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, + 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, + 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, + 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, + 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, + 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, + 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, + 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, + 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, + 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 
0x75, 0x65, + 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, + 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, + 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, + 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, + 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, + 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, + 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x12, 0x32, 0x0a, 
0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, + 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, + 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, + 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, + 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, + 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 
0x08, 0x52, 0x14, 0x53, 0x65, 0x61, + 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, + 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, + 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, + 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, + 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, + 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, + 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, + 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 
0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, + 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, + 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, + 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, + 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, + 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, + 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, + 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, + 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, + 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, + 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x8f, 0x02, 0x0a, 0x11, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, + 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 
0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, + 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, + 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, + 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x2a, 0x40, 0x0a, 0x0c, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, + 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, + 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, + 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x2a, 0x20, + 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, + 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, + 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, + 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 
0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, - 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 
0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, + 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, + 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 
0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3059,13 +3342,13 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 34) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 37) var file_management_proto_goTypes = []interface{}{ - (HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol - (DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider - (FirewallRuleDirection)(0), // 2: management.FirewallRule.direction - (FirewallRuleAction)(0), // 3: management.FirewallRule.action - (FirewallRuleProtocol)(0), // 4: management.FirewallRule.protocol + (RuleProtocol)(0), // 0: management.RuleProtocol + (RuleDirection)(0), // 1: management.RuleDirection + (RuleAction)(0), // 2: management.RuleAction + (HostConfig_Protocol)(0), // 3: management.HostConfig.Protocol + (DeviceAuthorizationFlowProvider)(0), // 4: management.DeviceAuthorizationFlow.provider (*EncryptedMessage)(nil), // 5: management.EncryptedMessage (*SyncRequest)(nil), // 6: management.SyncRequest (*SyncResponse)(nil), // 7: management.SyncResponse @@ -3100,7 +3383,10 @@ var file_management_proto_goTypes = []interface{}{ (*FirewallRule)(nil), // 36: management.FirewallRule (*NetworkAddress)(nil), // 37: management.NetworkAddress (*Checks)(nil), // 38: management.Checks - (*timestamppb.Timestamp)(nil), // 39: google.protobuf.Timestamp + (*PortInfo)(nil), // 39: management.PortInfo + (*RouteFirewallRule)(nil), // 40: management.RouteFirewallRule + (*PortInfo_Range)(nil), // 41: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 42: google.protobuf.Timestamp } var 
file_management_proto_depIdxs = []int32{ 13, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta @@ -3118,12 +3404,12 @@ var file_management_proto_depIdxs = []int32{ 17, // 12: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig 21, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig 38, // 14: management.LoginResponse.Checks:type_name -> management.Checks - 39, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 42, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 18, // 16: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig 20, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig 18, // 18: management.WiretrusteeConfig.signal:type_name -> management.HostConfig 19, // 19: management.WiretrusteeConfig.relay:type_name -> management.RelayConfig - 0, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 3, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol 18, // 21: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig 24, // 22: management.PeerConfig.sshConfig:type_name -> management.SSHConfig 21, // 23: management.NetworkMap.peerConfig:type_name -> management.PeerConfig @@ -3132,36 +3418,41 @@ var file_management_proto_depIdxs = []int32{ 31, // 26: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig 23, // 27: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig 36, // 28: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 24, // 29: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 1, // 30: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 29, // 31: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 29, // 
32: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 34, // 33: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 32, // 34: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 33, // 35: management.CustomZone.Records:type_name -> management.SimpleRecord - 35, // 36: management.NameServerGroup.NameServers:type_name -> management.NameServer - 2, // 37: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction - 3, // 38: management.FirewallRule.Action:type_name -> management.FirewallRule.action - 4, // 39: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol - 5, // 40: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 41: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 16, // 42: management.ManagementService.GetServerKey:input_type -> management.Empty - 16, // 43: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 44: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 45: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 46: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 47: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 48: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 15, // 49: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 16, // 50: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 51: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 52: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 16, // 53: 
management.ManagementService.SyncMeta:output_type -> management.Empty - 47, // [47:54] is the sub-list for method output_type - 40, // [40:47] is the sub-list for method input_type - 40, // [40:40] is the sub-list for extension type_name - 40, // [40:40] is the sub-list for extension extendee - 0, // [0:40] is the sub-list for field type_name + 40, // 29: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 24, // 30: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 4, // 31: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 29, // 32: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 29, // 33: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 34, // 34: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 32, // 35: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 33, // 36: management.CustomZone.Records:type_name -> management.SimpleRecord + 35, // 37: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 38: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 39: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 40: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 41, // 41: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 42: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 43: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 39, // 44: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 5, // 45: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 46: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 16, // 47: management.ManagementService.GetServerKey:input_type 
-> management.Empty + 16, // 48: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 49: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 50: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 51: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 52: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 53: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 15, // 54: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 16, // 55: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 56: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 57: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 16, // 58: management.ManagementService.SyncMeta:output_type -> management.Empty + 52, // [52:59] is the sub-list for method output_type + 45, // [45:52] is the sub-list for method input_type + 45, // [45:45] is the sub-list for extension type_name + 45, // [45:45] is the sub-list for extension extendee + 0, // [0:45] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -3578,6 +3869,46 @@ func file_management_proto_init() { return nil } } + file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RouteFirewallRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + 
file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo_Range); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_management_proto_msgTypes[34].OneofWrappers = []interface{}{ + (*PortInfo_Port)(nil), + (*PortInfo_Range_)(nil), } type x struct{} out := protoimpl.TypeBuilder{ @@ -3585,7 +3916,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 34, + NumMessages: 37, NumExtensions: 0, NumServices: 1, }, diff --git a/management/proto/management.proto b/management/proto/management.proto index c5646820f..fe6a828b1 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -254,6 +254,12 @@ message NetworkMap { // firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality. bool firewallRulesIsEmpty = 9; + + // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer + repeated RouteFirewallRule routesFirewallRules = 10; + + // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. + bool routesFirewallRulesIsEmpty = 11; } // RemotePeerConfig represents a configuration of a remote peer. 
@@ -384,29 +390,32 @@ message NameServer { int64 Port = 3; } +enum RuleProtocol { + UNKNOWN = 0; + ALL = 1; + TCP = 2; + UDP = 3; + ICMP = 4; +} + +enum RuleDirection { + IN = 0; + OUT = 1; +} + +enum RuleAction { + ACCEPT = 0; + DROP = 1; +} + + // FirewallRule represents a firewall rule message FirewallRule { string PeerIP = 1; - direction Direction = 2; - action Action = 3; - protocol Protocol = 4; + RuleDirection Direction = 2; + RuleAction Action = 3; + RuleProtocol Protocol = 4; string Port = 5; - - enum direction { - IN = 0; - OUT = 1; - } - enum action { - ACCEPT = 0; - DROP = 1; - } - enum protocol { - UNKNOWN = 0; - ALL = 1; - TCP = 2; - UDP = 3; - ICMP = 4; - } } message NetworkAddress { @@ -415,5 +424,40 @@ message NetworkAddress { } message Checks { - repeated string Files= 1; + repeated string Files = 1; } + + +message PortInfo { + oneof portSelection { + uint32 port = 1; + Range range = 2; + } + + message Range { + uint32 start = 1; + uint32 end = 2; + } +} + +// RouteFirewallRule signifies a firewall rule applicable for a routed network. +message RouteFirewallRule { + // sourceRanges IP ranges of the routing peers. + repeated string sourceRanges = 1; + + // Action to be taken by the firewall when the rule is applicable. + RuleAction action = 2; + + // Network prefix for the routed network. + string destination = 3; + + // Protocol of the routed network. + RuleProtocol protocol = 4; + + // Details about the port. + PortInfo portInfo = 5; + + // IsDynamic indicates if the route is a DNS route. 
+ bool isDynamic = 6; +} + diff --git a/management/server/account.go b/management/server/account.go index 710b6f62f..d5e8c8cf8 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -113,7 +113,7 @@ type AccountManager interface { DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) - CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) @@ -460,6 +460,7 @@ func (a *Account) GetPeerNetworkMap( } routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) + routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) dnsUpdate := nbdns.Config{ @@ -483,6 +484,7 @@ func (a *Account) GetPeerNetworkMap( DNSConfig: dnsUpdate, OfflinePeers: expiredPeers, FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, } if metrics != nil { diff --git a/management/server/account_test.go 
b/management/server/account_test.go index 303261bea..e554ae493 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1599,9 +1599,10 @@ func TestAccount_Copy(t *testing.T) { }, Routes: map[route.ID]*route.Route{ "route1": { - ID: "route1", - PeerGroups: []string{}, - Groups: []string{"group1"}, + ID: "route1", + PeerGroups: []string{}, + Groups: []string{"group1"}, + AccessControlGroups: []string{}, }, }, NameServerGroups: map[string]*nbdns.NameServerGroup{ diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index cda3bc748..4c4ef6c3c 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -596,6 +596,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 + routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) + response.NetworkMap.RoutesFirewallRules = routesFirewallRules + response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 + return response } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 2463f830e..fd0343e97 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -727,17 +727,39 @@ components: enum: ["all", "tcp", "udp", "icmp"] example: "tcp" ports: - description: Policy rule affected ports or it ranges list + description: Policy rule affected ports type: array items: type: string example: "80" + port_ranges: + description: Policy rule affected ports ranges list + type: array + items: + $ref: '#/components/schemas/RulePortRange' required: - name - enabled - bidirectional - protocol - action + + RulePortRange: + description: Policy rule affected ports range + type: object + properties: + start: + description: The starting port of the range + type: integer + example: 80 + 
end: + description: The ending port of the range + type: integer + example: 320 + required: + - start + - end + PolicyRuleUpdate: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' @@ -1106,6 +1128,12 @@ components: description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore type: boolean example: true + access_control_groups: + description: Access control group identifier associated with route. + type: array + items: + type: string + example: "chacbco6lnnbn6cg5s91" required: - id - description diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index b219d38fd..570ec03c5 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -780,7 +780,10 @@ type PolicyRule struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -816,7 +819,10 @@ type PolicyRuleMinimum struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -852,7 +858,10 @@ type PolicyRuleUpdate struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ 
-935,6 +944,9 @@ type ProcessCheck struct { // Route defines model for Route. type Route struct { + // AccessControlGroups Access control group identifier associated with route. + AccessControlGroups *[]string `json:"access_control_groups,omitempty"` + // Description Route description Description string `json:"description"` @@ -977,6 +989,9 @@ type Route struct { // RouteRequest defines model for RouteRequest. type RouteRequest struct { + // AccessControlGroups Access control group identifier associated with route. + AccessControlGroups *[]string `json:"access_control_groups,omitempty"` + // Description Route description Description string `json:"description"` @@ -1011,6 +1026,15 @@ type RouteRequest struct { PeerGroups *[]string `json:"peer_groups,omitempty"` } +// RulePortRange Policy rule affected ports range +type RulePortRange struct { + // End The ending port of the range + End int `json:"end"` + + // Start The starting port of the range + Start int `json:"start"` +} + // SetupKey defines model for SetupKey. 
type SetupKey struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 225d7e1f3..73f3803b5 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -172,6 +172,11 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } + if (rule.Ports != nil && len(*rule.Ports) != 0) && (rule.PortRanges != nil && len(*rule.PortRanges) != 0) { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either individual ports or port ranges, not both"), w) + return + } + if rule.Ports != nil && len(*rule.Ports) != 0 { for _, v := range *rule.Ports { if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 { @@ -182,10 +187,23 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID } } + if rule.PortRanges != nil && len(*rule.PortRanges) != 0 { + for _, portRange := range *rule.PortRanges { + if portRange.Start < 1 || portRange.End > 65535 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) + return + } + pr.PortRanges = append(pr.PortRanges, server.RulePortRange{ + Start: uint16(portRange.Start), + End: uint16(portRange.End), + }) + } + } + // validate policy object switch pr.Protocol { case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: - if len(pr.Ports) != 0 { + if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return } @@ -194,7 +212,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: - if !pr.Bidirectional && len(pr.Ports) == 0 { + if !pr.Bidirectional && (len(pr.Ports) == 0 || 
len(pr.PortRanges) != 0) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } @@ -320,6 +338,17 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rule.Ports = &portsCopy } + if len(r.PortRanges) != 0 { + portRanges := make([]api.RulePortRange, 0, len(r.PortRanges)) + for _, portRange := range r.PortRanges { + portRanges = append(portRanges, api.RulePortRange{ + End: int(portRange.End), + Start: int(portRange.Start), + }) + } + rule.PortRanges = &portRanges + } + for _, gid := range r.Sources { _, ok := cache[gid] if ok { diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 0932e6445..ce4edee4f 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -117,9 +117,14 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { peerGroupIds = *req.PeerGroups } + var accessControlGroupIds []string + if req.AccessControlGroups != nil { + accessControlGroupIds = *req.AccessControlGroups + } + newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, - req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute, - ) + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute) + if err != nil { util.WriteError(r.Context(), err, w) return @@ -233,6 +238,10 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } + if req.AccessControlGroups != nil { + newRoute.AccessControlGroups = *req.AccessControlGroups + } + err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) if err != nil { util.WriteError(r.Context(), err, w) @@ -326,6 +335,9 @@ func 
toRouteResponse(serverRoute *route.Route) (*api.Route, error) { if len(serverRoute.PeerGroups) > 0 { route.PeerGroups = &serverRoute.PeerGroups } + if len(serverRoute.AccessControlGroups) > 0 { + route.AccessControlGroups = &serverRoute.AccessControlGroups + } return route, nil } diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 2c367cac3..83bd7004d 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -105,7 +105,7 @@ func initRoutesTestData() *RoutesHandler { } return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { + CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } @@ -119,18 +119,19 @@ func initRoutesTestData() *RoutesHandler { } return &route.Route{ - ID: existingRouteID, - NetID: netID, - Peer: peerID, - PeerGroups: peerGroups, - Network: prefix, - Domains: domains, - NetworkType: networkType, - Description: description, - Masquerade: masquerade, - Enabled: enabled, - Groups: groups, - KeepRoute: keepRoute, + ID: existingRouteID, + NetID: netID, + Peer: peerID, + PeerGroups: peerGroups, + Network: prefix, + Domains: domains, + NetworkType: networkType, + Description: description, + Masquerade: 
masquerade, + Enabled: enabled, + Groups: groups, + KeepRoute: keepRoute, + AccessControlGroups: accessControlGroups, }, nil }, SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error { @@ -268,6 +269,27 @@ func TestRoutesHandlers(t *testing.T) { Groups: []string{existingGroupID}, }, }, + { + name: "POST OK With Access Control Groups", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRoute: &api.Route{ + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: toPtr("192.168.0.0/16"), + Peer: &existingPeerID, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + AccessControlGroups: &[]string{existingGroupID}, + }, + }, { name: "POST Non Linux Peer", requestType: http.MethodPost, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index df12ec1c4..b399be822 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -58,7 +58,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID 
string, keepRoute bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -367,7 +367,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, if am.DeleteRuleFunc != nil { return am.DeleteRuleFunc(ctx, accountID, ruleID, userID) } - return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented") + return status.Errorf(codes.Unimplemented, "method DeletePeerRule is not implemented") } // GetPolicy mock implementation of GetPolicy from server.AccountManager interface @@ -442,9 +442,9 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, 
userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } diff --git a/management/server/network.go b/management/server/network.go index 0e7d753a7..a5b188b46 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -26,12 +26,13 @@ const ( ) type NetworkMap struct { - Peers []*nbpeer.Peer - Network *Network - Routes []*route.Route - DNSConfig nbdns.Config - OfflinePeers []*nbpeer.Peer - FirewallRules []*FirewallRule + Peers []*nbpeer.Peer + Network *Network + Routes []*route.Route + DNSConfig nbdns.Config + OfflinePeers []*nbpeer.Peer + FirewallRules []*FirewallRule + RoutesFirewallRules []*RouteFirewallRule } type Network struct { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index d329e04bc..387adb91d 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -646,7 +646,6 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { }) } - } func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) { @@ -991,9 +990,9 @@ func TestToSyncResponse(t *testing.T) { // assert network map Firewall assert.Equal(t, 1, len(response.NetworkMap.FirewallRules)) assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP) - assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction) - assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action) - assert.Equal(t, proto.FirewallRule_TCP, 
response.NetworkMap.FirewallRules[0].Protocol) + assert.Equal(t, proto.RuleDirection_IN, response.NetworkMap.FirewallRules[0].Direction) + assert.Equal(t, proto.RuleAction_ACCEPT, response.NetworkMap.FirewallRules[0].Action) + assert.Equal(t, proto.RuleProtocol_TCP, response.NetworkMap.FirewallRules[0].Protocol) assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port) // assert posture checks assert.Equal(t, 1, len(response.Checks)) diff --git a/management/server/policy.go b/management/server/policy.go index 5d07ba8f8..75647de44 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -76,6 +76,12 @@ type PolicyUpdateOperation struct { Values []string } +// RulePortRange represents a range of ports for a firewall rule. +type RulePortRange struct { + Start uint16 + End uint16 +} + // PolicyRule is the metadata of the policy type PolicyRule struct { // ID of the policy rule @@ -110,6 +116,9 @@ type PolicyRule struct { // Ports or it ranges list Ports []string `gorm:"serializer:json"` + + // PortRanges a list of port ranges. 
+ PortRanges []RulePortRange `gorm:"serializer:json"` } // Copy returns a copy of a policy rule @@ -125,10 +134,12 @@ func (pm *PolicyRule) Copy() *PolicyRule { Bidirectional: pm.Bidirectional, Protocol: pm.Protocol, Ports: make([]string, len(pm.Ports)), + PortRanges: make([]RulePortRange, len(pm.PortRanges)), } copy(rule.Destinations, pm.Destinations) copy(rule.Sources, pm.Sources) copy(rule.Ports, pm.Ports) + copy(rule.PortRanges, pm.PortRanges) return rule } @@ -445,36 +456,17 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli return nil } -func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(update)) - for i := range update { - direction := proto.FirewallRule_IN - if update[i].Direction == firewallRuleDirectionOUT { - direction = proto.FirewallRule_OUT - } - action := proto.FirewallRule_ACCEPT - if update[i].Action == string(PolicyTrafficActionDrop) { - action = proto.FirewallRule_DROP - } - - protocol := proto.FirewallRule_UNKNOWN - switch PolicyRuleProtocolType(update[i].Protocol) { - case PolicyRuleProtocolALL: - protocol = proto.FirewallRule_ALL - case PolicyRuleProtocolTCP: - protocol = proto.FirewallRule_TCP - case PolicyRuleProtocolUDP: - protocol = proto.FirewallRule_UDP - case PolicyRuleProtocolICMP: - protocol = proto.FirewallRule_ICMP - } +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] result[i] = &proto.FirewallRule{ - PeerIP: update[i].PeerIP, - Direction: direction, - Action: action, - Protocol: protocol, - Port: update[i].Port, + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result diff --git a/management/server/route.go b/management/server/route.go index 6c1c8b1b3..39ee6170c 100644 --- 
a/management/server/route.go +++ b/management/server/route.go @@ -4,9 +4,15 @@ import ( "context" "fmt" "net/netip" + "slices" + "strconv" + "strings" "unicode/utf8" "github.com/rs/xid" + log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" @@ -15,6 +21,30 @@ import ( "github.com/netbirdio/netbird/route" ) +// RouteFirewallRule a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + // SourceRanges IP ranges of the routing peers. + SourceRanges []string + + // Action of the traffic when the rule is applicable + Action string + + // Destination a network prefix for the routed traffic + Destination string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port uint16 + + // PortRange represents the range of ports for a firewall rule + PortRange RulePortRange + + // isDynamic indicates whether the rule is for DNS routing + IsDynamic bool +} + // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) @@ -112,7 +142,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { } // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs 
[]string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -157,6 +187,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } } + if len(accessControlGroupIDs) > 0 { + err = validateGroups(accessControlGroupIDs, account.Groups) + if err != nil { + return nil, err + } + } + err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) if err != nil { return nil, err @@ -187,6 +224,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri newRoute.Enabled = enabled newRoute.Groups = groups newRoute.KeepRoute = keepRoute + newRoute.AccessControlGroups = accessControlGroupIDs if account.Routes == nil { account.Routes = make(map[route.ID]*route.Route) @@ -258,6 +296,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } } + if len(routeToSave.AccessControlGroups) > 0 { + err = validateGroups(routeToSave.AccessControlGroups, account.Groups) + if err != nil { + return err + } + } + err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) if err != nil { return err @@ -351,3 +396,248 @@ func getPlaceholderIP() netip.Prefix { // Using an IP from the documentation range to minimize impact in case older clients try to set a route return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } + +// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. 
+func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) + + enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) + for _, route := range enabledRoutes { + // If no access control groups are specified, accept all traffic. + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN) + routesFirewallRules = append(routesFirewallRules, rules...) 
+ } + } + } + + return routesFirewallRules +} + +func getDefaultPermit(route *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if route.Network.Addr().Is6() { + sources = []string{"::/0"} + } + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: route.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + IsDynamic: route.IsDynamic(), + } + + rules = append(rules, &rule) + + // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally + if route.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups +// and returns a list of policies that have rules with destinations matching the specified groups. +func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + group, ok := account.Groups[groupID] + if !ok { + continue + } + + for _, policy := range account.Policies { + for _, rule := range policy.Rules { + exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { + return groupID == group.ID + }) + if exist { + routePolicies = append(routePolicies, policy) + continue + } + } + } + } + + return routePolicies +} + +// generateRouteFirewallRules generates a list of firewall rules for a given route. 
+func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { + rulesExists := make(map[string]struct{}) + rules := make([]*RouteFirewallRule, 0) + + sourceRanges := make([]string, 0, len(groupPeers)) + for _, peer := range groupPeers { + if peer == nil { + continue + } + sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + } + + baseRule := RouteFirewallRule{ + SourceRanges: sourceRanges, + Action: string(rule.Action), + Destination: route.Network.String(), + Protocol: string(rule.Protocol), + IsDynamic: route.IsDynamic(), + } + + // generate rule for port range + if len(rule.Ports) == 0 { + rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) + } else { + rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) + + } + + // TODO: generate IPv6 rules for dynamic routes + + return rules +} + +// generateRuleIDBase generates the base rule ID for checking duplicates. +func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { + return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action +} + +// generateRulesForPeer generates rules for a given peer based on ports and port ranges. 
+func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + + ruleIDBase := generateRuleIDBase(rule, baseRule) + if len(rule.Ports) == 0 { + if len(rule.PortRanges) == 0 { + if _, ok := rulesExists[ruleIDBase]; !ok { + rulesExists[ruleIDBase] = struct{}{} + rules = append(rules, &baseRule) + } + } else { + for _, portRange := range rule.PortRanges { + ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) + if _, ok := rulesExists[ruleID]; !ok { + rulesExists[ruleID] = struct{}{} + pr := baseRule + pr.PortRange = portRange + rules = append(rules, &pr) + } + } + } + return rules + } + + return rules +} + +// generateRulesWithPorts generates rules when specific ports are provided. +func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + ruleIDBase := generateRuleIDBase(rule, baseRule) + + for _, port := range rule.Ports { + ruleID := ruleIDBase + port + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + pr := baseRule + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) + continue + } + + pr.Port = uint16(p) + rules = append(rules, &pr) + } + + return rules +} + +func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule { + result := make([]*proto.RouteFirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + result[i] = &proto.RouteFirewallRule{ + SourceRanges: rule.SourceRanges, + Action: getProtoAction(rule.Action), + Destination: rule.Destination, + Protocol: getProtoProtocol(rule.Protocol), + PortInfo: getProtoPortInfo(rule), + IsDynamic: rule.IsDynamic, + } + } + + return result +} + +// getProtoDirection 
converts the direction to proto.RuleDirection. +func getProtoDirection(direction int) proto.RuleDirection { + if direction == firewallRuleDirectionOUT { + return proto.RuleDirection_OUT + } + return proto.RuleDirection_IN +} + +// getProtoAction converts the action to proto.RuleAction. +func getProtoAction(action string) proto.RuleAction { + if action == string(PolicyTrafficActionDrop) { + return proto.RuleAction_DROP + } + return proto.RuleAction_ACCEPT +} + +// getProtoProtocol converts the protocol to proto.RuleProtocol. +func getProtoProtocol(protocol string) proto.RuleProtocol { + switch PolicyRuleProtocolType(protocol) { + case PolicyRuleProtocolALL: + return proto.RuleProtocol_ALL + case PolicyRuleProtocolTCP: + return proto.RuleProtocol_TCP + case PolicyRuleProtocolUDP: + return proto.RuleProtocol_UDP + case PolicyRuleProtocolICMP: + return proto.RuleProtocol_ICMP + default: + return proto.RuleProtocol_UNKNOWN + } +} + +// getProtoPortInfo converts the port info to proto.PortInfo. 
+func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { + var portInfo proto.PortInfo + if rule.Port != 0 { + portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} + } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { + portInfo.PortSelection = &proto.PortInfo_Range_{ + Range: &proto.PortInfo_Range{ + Start: uint32(portRange.Start), + End: uint32(portRange.End), + }, + } + } + return &portInfo +} diff --git a/management/server/route_test.go b/management/server/route_test.go index 4533c6b7e..b556816be 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "fmt" + "net" "net/netip" "testing" @@ -44,18 +46,19 @@ var existingDomains = domain.List{"example.com"} func TestCreateRoute(t *testing.T) { type input struct { - network netip.Prefix - domains domain.List - keepRoute bool - networkType route.NetworkType - netID route.NetID - peerKey string - peerGroupIDs []string - description string - masquerade bool - metric int - enabled bool - groups []string + network netip.Prefix + domains domain.List + keepRoute bool + networkType route.NetworkType + netID route.NetID + peerKey string + peerGroupIDs []string + description string + masquerade bool + metric int + enabled bool + groups []string + accessControlGroups []string } testCases := []struct { @@ -69,100 +72,107 @@ func TestCreateRoute(t *testing.T) { { name: "Happy Path Network", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerKey: peer1ID, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerKey: peer1ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: 
[]string{routeGroup1}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetworkType: route.IPv4Network, - NetID: "happy", - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetworkType: route.IPv4Network, + NetID: "happy", + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + AccessControlGroups: []string{routeGroup1}, }, }, { name: "Happy Path Domains", inputArgs: input{ - domains: domain.List{"domain1", "domain2"}, - keepRoute: true, - networkType: route.DomainNetwork, - netID: "happy", - peerKey: peer1ID, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + domains: domain.List{"domain1", "domain2"}, + keepRoute: true, + networkType: route.DomainNetwork, + netID: "happy", + peerKey: peer1ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup1}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.0.2.0/32"), - Domains: domain.List{"domain1", "domain2"}, - NetworkType: route.DomainNetwork, - NetID: "happy", - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - KeepRoute: true, + Network: netip.MustParsePrefix("192.0.2.0/32"), + Domains: domain.List{"domain1", "domain2"}, + NetworkType: route.DomainNetwork, + NetID: "happy", + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + KeepRoute: true, + AccessControlGroups: []string{routeGroup1}, }, }, { name: "Happy Path Peer Groups", inputArgs: input{ - network: 
netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1, routeGroup2}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1, routeGroup2}, + accessControlGroups: []string{routeGroup1, routeGroup2}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetworkType: route.IPv4Network, - NetID: "happy", - PeerGroups: []string{routeGroupHA1, routeGroupHA2}, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1, routeGroup2}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetworkType: route.IPv4Network, + NetID: "happy", + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + AccessControlGroups: []string{routeGroup1, routeGroup2}, }, }, { name: "Both network and domains provided should fail", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - domains: domain.List{"domain1", "domain2"}, - netID: "happy", - peerKey: peer1ID, - peerGroupIDs: []string{routeGroupHA1}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + domains: domain.List{"domain1", "domain2"}, + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: 
[]string{routeGroup2}, }, errFunc: require.Error, shouldCreate: false, @@ -170,16 +180,17 @@ func TestCreateRoute(t *testing.T) { { name: "Both peer and peer_groups Provided Should Fail", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerKey: peer1ID, - peerGroupIDs: []string{routeGroupHA1}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup2}, }, errFunc: require.Error, shouldCreate: false, @@ -423,13 +434,13 @@ func TestCreateRoute(t *testing.T) { if testCase.createInitRoute { groupAll, errInit := account.GetGroupAll() require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) } - outRoute, err := 
am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) + outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) @@ -1037,15 +1048,16 @@ func TestDeleteRoute(t *testing.T) { func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { baseRoute := &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superNet", - NetworkType: route.IPv4Network, - PeerGroups: []string{routeGroupHA1, routeGroupHA2}, - Description: "ha route", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1, routeGroup2}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "ha route", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + AccessControlGroups: []string{routeGroup1}, } am, err := createRouterManager(t) @@ -1062,7 +1074,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, 
baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute) + newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) @@ -1127,16 +1139,17 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { // no routes for peer in different groups // no routes when route is deleted baseRoute := &route.Route{ - ID: "testingRoute", - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superNet", - NetworkType: route.IPv4Network, - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, + ID: "testingRoute", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + AccessControlGroups: []string{routeGroup1}, } am, err := createRouterManager(t) @@ -1153,7 +1166,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute) + createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, 
[]string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1467,3 +1480,300 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return am.Store.GetAccount(context.Background(), account.Id) } + +func TestAccount_getPeersRoutesFirewall(t *testing.T) { + var ( + peerBIp = "100.65.80.39" + peerCIp = "100.65.254.139" + peerHIp = "100.65.29.55" + ) + + account := &Account{ + Peers: map[string]*nbpeer.Peer{ + "peerA": { + ID: "peerA", + IP: net.ParseIP("100.65.14.88"), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerB": { + ID: "peerB", + IP: net.ParseIP(peerBIp), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{}, + }, + "peerC": { + ID: "peerC", + IP: net.ParseIP(peerCIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerD": { + ID: "peerD", + IP: net.ParseIP("100.65.62.5"), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerE": { + ID: "peerE", + IP: net.ParseIP("100.65.32.206"), + Key: peer1Key, + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerF": { + ID: "peerF", + IP: net.ParseIP("100.65.250.202"), + Status: &nbpeer.PeerStatus{}, + }, + "peerG": { + ID: "peerG", + IP: net.ParseIP("100.65.13.186"), + Status: &nbpeer.PeerStatus{}, + }, + "peerH": { + ID: "peerH", + IP: net.ParseIP(peerHIp), + Status: &nbpeer.PeerStatus{}, + }, + }, + Groups: map[string]*nbgroup.Group{ + "routingPeer1": { + ID: "routingPeer1", + Name: "RoutingPeer1", + Peers: []string{ + "peerA", + }, + }, + "routingPeer2": { + ID: "routingPeer2", + Name: "RoutingPeer2", + Peers: []string{ + "peerD", + }, + }, + "route1": { + ID: "route1", + Name: "Route1", + Peers: []string{}, + }, + "route2": { + ID: 
"route2", + Name: "Route2", + Peers: []string{}, + }, + "finance": { + ID: "finance", + Name: "Finance", + Peers: []string{ + "peerF", + "peerG", + }, + }, + "dev": { + ID: "dev", + Name: "Dev", + Peers: []string{ + "peerC", + "peerH", + "peerB", + }, + }, + "contractors": { + ID: "contractors", + Name: "Contractors", + Peers: []string{}, + }, + }, + Routes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "route1", + NetworkType: route.IPv4Network, + PeerGroups: []string{"routingPeer1", "routingPeer2"}, + Description: "Route1 ha route", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"dev"}, + AccessControlGroups: []string{"route1"}, + }, + "route2": { + ID: "route2", + Network: existingNetwork, + NetID: "route2", + NetworkType: route.IPv4Network, + Peer: "peerE", + Description: "Allow", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"finance"}, + AccessControlGroups: []string{"route2"}, + }, + "route3": { + ID: "route3", + Network: netip.MustParsePrefix("192.0.2.0/32"), + Domains: domain.List{"example.com"}, + NetID: "route3", + NetworkType: route.DomainNetwork, + Peer: "peerE", + Description: "Allow all traffic to routed DNS network", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"contractors"}, + AccessControlGroups: []string{}, + }, + }, + Policies: []*Policy{ + { + ID: "RuleRoute1", + Name: "Route1", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute1", + Name: "ruleRoute1", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + Ports: []string{"80", "320"}, + Sources: []string{ + "dev", + }, + Destinations: []string{ + "route1", + }, + }, + }, + }, + { + ID: "RuleRoute2", + Name: "Route2", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute2", + Name: "ruleRoute2", + Bidirectional: true, + Enabled: true, + Protocol: 
PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + PortRanges: []RulePortRange{ + { + Start: 80, + End: 350, + }, { + Start: 80, + End: 350, + }, + }, + Sources: []string{ + "finance", + }, + Destinations: []string{ + "route2", + }, + }, + }, + }, + }, + } + + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + t.Run("check applied policies for the route", func(t *testing.T) { + route1 := account.Routes["route1"] + policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) + assert.Len(t, policies, 1) + + route2 := account.Routes["route2"] + policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) + assert.Len(t, policies, 1) + + route3 := account.Routes["route3"] + policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) + assert.Len(t, policies, 0) + }) + + t.Run("check peer routes firewall rules", func(t *testing.T) { + routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) + assert.Len(t, routesFirewallRules, 2) + + expectedRoutesFirewallRules := []*RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerCIp), + fmt.Sprintf(AllowedIPsFormat, peerHIp), + fmt.Sprintf(AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerCIp), + fmt.Sprintf(AllowedIPsFormat, peerHIp), + fmt.Sprintf(AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 320, + }, + } + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + //peerD is also the routing peer for route1, should contain same routes firewall rules as peerA + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) + assert.Len(t, routesFirewallRules, 
2) + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + // peerE is a single routing peer for route 2 and route 3 + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) + assert.Len(t, routesFirewallRules, 3) + + expectedRoutesFirewallRules = []*RouteFirewallRule{ + { + SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, + Action: "accept", + Destination: existingNetwork.String(), + Protocol: "tcp", + PortRange: RulePortRange{Start: 80, End: 350}, + }, + { + SourceRanges: []string{"0.0.0.0/0"}, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + IsDynamic: true, + }, + { + SourceRanges: []string{"::/0"}, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + IsDynamic: true, + }, + } + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + // peerC is part of route1 distribution groups but should not receive the routes firewall rules + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + assert.Len(t, routesFirewallRules, 0) + }) + +} diff --git a/route/route.go b/route/route.go index eb6c36bd8..e23801e6e 100644 --- a/route/route.go +++ b/route/route.go @@ -100,6 +100,7 @@ type Route struct { Metric int Enabled bool Groups []string `gorm:"serializer:json"` + AccessControlGroups []string `gorm:"serializer:json"` } // EventMeta returns activity event meta related to the route @@ -123,6 +124,7 @@ func (r *Route) Copy() *Route { Masquerade: r.Masquerade, Enabled: r.Enabled, Groups: slices.Clone(r.Groups), + AccessControlGroups: slices.Clone(r.AccessControlGroups), } return route } @@ -147,7 +149,8 @@ func (r *Route) IsEqual(other *Route) bool { other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && - slices.Equal(r.PeerGroups, other.PeerGroups) + slices.Equal(r.PeerGroups, other.PeerGroups)&& + 
slices.Equal(r.AccessControlGroups, other.AccessControlGroups) } // IsDynamic returns if the route is dynamic, i.e. has domains From b7b08281336676f356c8e1032b1907c617b6439c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 2 Oct 2024 15:14:09 +0200 Subject: [PATCH 13/37] [client] Adjust relay worker log level and message (#2683) --- client/internal/peer/worker_relay.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 6bb385d3e..c02fccebc 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -74,7 +74,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) if err != nil { if errors.Is(err, relayClient.ErrConnAlreadyExists) { - w.log.Infof("do not need to reopen relay connection") + w.log.Debugf("handled offer by reusing existing relay connection") return } w.log.Errorf("failed to open connection via Relay: %s", err) From 7e5d3bdfe2306f69ef5daab3c742c4d206c69406 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:33:38 +0200 Subject: [PATCH 14/37] [signal] Move dummy signal message handling into dispatcher (#2686) --- go.mod | 2 +- go.sum | 4 ++-- signal/server/signal.go | 5 ----- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index c29ba0763..e7137ce5b 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 
github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 1f6cbb785..4563dc933 100644 --- a/go.sum +++ b/go.sum @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811- github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= diff --git a/signal/server/signal.go b/signal/server/signal.go index 386ce7238..63cc43bd7 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -71,11 +71,6 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.Debugf("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) - if msg.RemoteKey == "dummy" { - // Test 
message send during netbird status - return &proto.EncryptedMessage{}, nil - } - if _, found := s.registry.Get(msg.RemoteKey); found { s.forwardMessageToPeer(ctx, msg) return &proto.EncryptedMessage{}, nil From fd67892cb4fa0c4e4c23c0511796cc7ce9fe296c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 2 Oct 2024 18:24:22 +0200 Subject: [PATCH 15/37] [client] Refactor/iface pkg (#2646) Refactor the flat code structure --- .github/workflows/golang-test-freebsd.yml | 2 +- .github/workflows/golang-test-linux.yml | 2 +- client/android/client.go | 6 +-- client/cmd/login_test.go | 2 +- client/cmd/root_test.go | 2 +- client/cmd/up.go | 2 +- client/firewall/iface.go | 6 +-- client/firewall/iptables/manager_linux.go | 2 +- .../firewall/iptables/manager_linux_test.go | 2 +- client/firewall/nftables/acl_linux.go | 2 +- .../firewall/nftables/manager_linux_test.go | 2 +- client/firewall/uspfilter/uspfilter.go | 5 +- client/firewall/uspfilter/uspfilter_test.go | 21 ++++---- {iface => client/iface}/bind/bind.go | 0 {iface => client/iface}/bind/udp_mux.go | 0 .../iface}/bind/udp_mux_universal.go | 0 .../iface}/bind/udp_muxed_conn.go | 0 client/iface/configurer/err.go | 5 ++ .../iface/configurer/kernel_unix.go | 27 +++++----- {iface => client/iface/configurer}/name.go | 2 +- .../iface/configurer}/name_darwin.go | 2 +- {iface => client/iface/configurer}/uapi.go | 2 +- .../iface/configurer}/uapi_windows.go | 2 +- .../iface/configurer/usp.go | 24 ++++----- .../iface/configurer/usp_test.go | 2 +- client/iface/configurer/wgstats.go | 9 ++++ client/iface/device.go | 18 +++++++ .../iface/device/adapter.go | 2 +- {iface => client/iface/device}/address.go | 8 +-- .../iface/device/args.go | 2 +- .../iface/device/device_android.go | 54 +++++++++---------- .../iface/device/device_darwin.go | 49 ++++++++--------- .../iface/device/device_filter.go | 19 +++---- .../iface/device/device_filter_test.go | 13 ++--- .../iface/device/device_ios.go | 49 ++++++++--------- 
.../iface/device/device_kernel_unix.go | 31 +++++------ .../iface/device/device_netstack.go | 49 ++++++++--------- .../iface/device/device_usp_unix.go | 52 +++++++++--------- .../iface/device/device_windows.go | 47 ++++++++-------- client/iface/device/interface.go | 20 +++++++ .../iface/device/kernel_module.go | 2 +- .../iface/device/kernel_module_freebsd.go | 6 +-- .../iface/device/kernel_module_linux.go | 6 +-- .../iface/device/kernel_module_linux_test.go | 8 +-- .../iface/device/wg_link_freebsd.go | 5 +- .../iface/device/wg_link_linux.go | 2 +- {iface => client/iface/device}/wg_log.go | 2 +- client/iface/device/windows_guid.go | 4 ++ client/iface/device_android.go | 16 ++++++ {iface => client/iface}/freebsd/errors.go | 0 {iface => client/iface}/freebsd/iface.go | 0 .../iface}/freebsd/iface_internal_test.go | 0 {iface => client/iface}/freebsd/link.go | 0 {iface => client/iface}/iface.go | 53 +++++++++--------- {iface => client/iface}/iface_android.go | 9 ++-- {iface => client/iface}/iface_create.go | 0 {iface => client/iface}/iface_darwin.go | 13 ++--- {iface => client/iface}/iface_destroy_bsd.go | 0 .../iface}/iface_destroy_linux.go | 0 .../iface}/iface_destroy_mobile.go | 0 .../iface}/iface_destroy_windows.go | 0 {iface => client/iface}/iface_ios.go | 9 ++-- {iface => client/iface}/iface_moc.go | 24 +++++---- {iface => client/iface}/iface_test.go | 6 ++- {iface => client/iface}/iface_unix.go | 19 +++---- {iface => client/iface}/iface_windows.go | 15 +++--- {iface => client/iface}/iwginterface.go | 14 ++--- .../iface}/iwginterface_windows.go | 14 ++--- {iface => client/iface}/mocks/README.md | 0 {iface => client/iface}/mocks/filter.go | 2 +- .../iface}/mocks/iface/mocks/filter.go | 2 +- {iface => client/iface}/mocks/tun.go | 0 {iface => client/iface}/netstack/dialer.go | 0 {iface => client/iface}/netstack/env.go | 0 {iface => client/iface}/netstack/proxy.go | 0 {iface => client/iface}/netstack/tun.go | 0 client/internal/acl/manager_test.go | 2 +- 
client/internal/acl/mocks/iface_mapper.go | 5 +- client/internal/config.go | 2 +- client/internal/connect.go | 7 +-- client/internal/dns/response_writer_test.go | 2 +- client/internal/dns/server_test.go | 18 ++++--- client/internal/dns/wgiface.go | 10 ++-- client/internal/dns/wgiface_windows.go | 12 +++-- client/internal/engine.go | 13 ++--- client/internal/engine_test.go | 7 +-- client/internal/mobile_dependency.go | 4 +- client/internal/peer/conn.go | 5 +- client/internal/peer/conn_test.go | 2 +- client/internal/peer/status.go | 6 +-- client/internal/peer/worker_ice.go | 4 +- client/internal/routemanager/client.go | 2 +- client/internal/routemanager/dynamic/route.go | 2 +- client/internal/routemanager/manager.go | 5 +- client/internal/routemanager/manager_test.go | 2 +- client/internal/routemanager/mock.go | 2 +- .../internal/routemanager/server_android.go | 2 +- .../routemanager/server_nonandroid.go | 2 +- .../routemanager/sysctl/sysctl_linux.go | 2 +- .../routemanager/systemops/systemops.go | 2 +- .../systemops/systemops_generic.go | 2 +- .../systemops/systemops_generic_test.go | 2 +- iface/tun.go | 21 -------- iface/wg_configurer.go | 21 -------- util/net/net.go | 2 +- 105 files changed, 505 insertions(+), 438 deletions(-) rename {iface => client/iface}/bind/bind.go (100%) rename {iface => client/iface}/bind/udp_mux.go (100%) rename {iface => client/iface}/bind/udp_mux_universal.go (100%) rename {iface => client/iface}/bind/udp_muxed_conn.go (100%) create mode 100644 client/iface/configurer/err.go rename iface/wg_configurer_kernel_unix.go => client/iface/configurer/kernel_unix.go (83%) rename {iface => client/iface/configurer}/name.go (87%) rename {iface => client/iface/configurer}/name_darwin.go (86%) rename {iface => client/iface/configurer}/uapi.go (96%) rename {iface => client/iface/configurer}/uapi_windows.go (88%) rename iface/wg_configurer_usp.go => client/iface/configurer/usp.go (93%) rename iface/wg_configurer_usp_test.go => 
client/iface/configurer/usp_test.go (99%) create mode 100644 client/iface/configurer/wgstats.go create mode 100644 client/iface/device.go rename iface/tun_adapter.go => client/iface/device/adapter.go (94%) rename {iface => client/iface/device}/address.go (69%) rename iface/tun_args.go => client/iface/device/args.go (88%) rename iface/tun_android.go => client/iface/device/device_android.go (61%) rename iface/tun_darwin.go => client/iface/device/device_darwin.go (69%) rename iface/device_wrapper.go => client/iface/device/device_filter.go (81%) rename iface/device_wrapper_test.go => client/iface/device/device_filter_test.go (95%) rename iface/tun_ios.go => client/iface/device/device_ios.go (63%) rename iface/tun_kernel_unix.go => client/iface/device/device_kernel_unix.go (75%) rename iface/tun_netstack.go => client/iface/device/device_netstack.go (56%) rename iface/tun_usp_unix.go => client/iface/device/device_usp_unix.go (63%) rename iface/tun_windows.go => client/iface/device/device_windows.go (75%) create mode 100644 client/iface/device/interface.go rename iface/module.go => client/iface/device/kernel_module.go (92%) rename iface/module_freebsd.go => client/iface/device/kernel_module_freebsd.go (84%) rename iface/module_linux.go => client/iface/device/kernel_module_linux.go (98%) rename iface/module_linux_test.go => client/iface/device/kernel_module_linux_test.go (98%) rename iface/tun_link_freebsd.go => client/iface/device/wg_link_freebsd.go (95%) rename iface/tun_link_linux.go => client/iface/device/wg_link_linux.go (99%) rename {iface => client/iface/device}/wg_log.go (93%) create mode 100644 client/iface/device/windows_guid.go create mode 100644 client/iface/device_android.go rename {iface => client/iface}/freebsd/errors.go (100%) rename {iface => client/iface}/freebsd/iface.go (100%) rename {iface => client/iface}/freebsd/iface_internal_test.go (100%) rename {iface => client/iface}/freebsd/link.go (100%) rename {iface => client/iface}/iface.go (79%) rename 
{iface => client/iface}/iface_android.go (67%) rename {iface => client/iface}/iface_create.go (100%) rename {iface => client/iface}/iface_darwin.go (68%) rename {iface => client/iface}/iface_destroy_bsd.go (100%) rename {iface => client/iface}/iface_destroy_linux.go (100%) rename {iface => client/iface}/iface_destroy_mobile.go (100%) rename {iface => client/iface}/iface_destroy_windows.go (100%) rename {iface => client/iface}/iface_ios.go (59%) rename {iface => client/iface}/iface_moc.go (76%) rename {iface => client/iface}/iface_test.go (98%) rename {iface => client/iface}/iface_unix.go (53%) rename {iface => client/iface}/iface_windows.go (52%) rename {iface => client/iface}/iwginterface.go (65%) rename {iface => client/iface}/iwginterface_windows.go (65%) rename {iface => client/iface}/mocks/README.md (100%) rename {iface => client/iface}/mocks/filter.go (97%) rename {iface => client/iface}/mocks/iface/mocks/filter.go (97%) rename {iface => client/iface}/mocks/tun.go (100%) rename {iface => client/iface}/netstack/dialer.go (100%) rename {iface => client/iface}/netstack/env.go (100%) rename {iface => client/iface}/netstack/proxy.go (100%) rename {iface => client/iface}/netstack/tun.go (100%) delete mode 100644 iface/tun.go delete mode 100644 iface/wg_configurer.go diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index 4f13ee30e..a2d743715 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -38,7 +38,7 @@ jobs: time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./formatter/... - time go test -timeout 1m -failfast ./iface/... + time go test -timeout 1m -failfast ./client/iface/... time go test -timeout 1m -failfast ./route/... time go test -timeout 1m -failfast ./sharedsock/... time go test -timeout 1m -failfast ./signal/... 
diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 2d5cf2856..524f35f6f 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -80,7 +80,7 @@ jobs: run: git --no-pager diff --exit-code - name: Generate Iface Test bin - run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/ + run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/ - name: Generate Shared Sock Test bin run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock diff --git a/client/android/client.go b/client/android/client.go index d937e132e..229bcd974 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" @@ -15,7 +16,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util/net" ) @@ -26,7 +26,7 @@ type ConnectionListener interface { // TunAdapter export internal TunAdapter for mobile type TunAdapter interface { - iface.TunAdapter + device.TunAdapter } // IFaceDiscover export internal IFaceDiscover for mobile @@ -51,7 +51,7 @@ func init() { // Client struct manage the life circle of background service type Client struct { cfgFile string - tunAdapter iface.TunAdapter + tunAdapter device.TunAdapter iFaceDiscover IFaceDiscover recorder *peer.Status ctxCancel context.CancelFunc diff --git a/client/cmd/login_test.go b/client/cmd/login_test.go index 6bb7eff4f..fa20435ea 100644 --- a/client/cmd/login_test.go +++ b/client/cmd/login_test.go @@ -5,8 +5,8 @@ import ( "strings" "testing" + "github.com/netbirdio/netbird/client/iface" 
"github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/cmd/root_test.go b/client/cmd/root_test.go index f2805cf35..4cbbe8783 100644 --- a/client/cmd/root_test.go +++ b/client/cmd/root_test.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) func TestInitCommands(t *testing.T) { diff --git a/client/cmd/up.go b/client/cmd/up.go index b447f7141..05ecce9e0 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -15,11 +15,11 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/firewall/iface.go b/client/firewall/iface.go index d0b5209c0..f349f9210 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -1,13 +1,13 @@ package firewall import ( - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface/device" ) // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { Name() string - Address() iface.WGAddress + Address() device.WGAddress IsUserspaceBind() bool - SetFilter(iface.PacketFilter) error + SetFilter(device.PacketFilter) error } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index fae41d9c5..6fefd58e6 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -11,7 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + 
"github.com/netbirdio/netbird/client/iface" ) // Manager of iptables firewall diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 0072aa159..498d8f58b 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) var ifaceMock = &iFaceMock{ diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 85cba9e1c..eaf7fb6a0 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) const ( diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 7f78a9a2e..904050a51 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -15,7 +15,7 @@ import ( "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) var ifaceMock = &iFaceMock{ diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 681058ea9..0e3ee9799 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -12,7 +12,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) const layerTypeAll = 0 @@ -23,7 +24,7 @@ var ( // 
IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { - SetFilter(iface.PacketFilter) error + SetFilter(device.PacketFilter) error Address() iface.WGAddress } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index dd7366fe9..c188deea4 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -11,15 +11,16 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) type IFaceMock struct { - SetFilterFunc func(iface.PacketFilter) error + SetFilterFunc func(device.PacketFilter) error AddressFunc func() iface.WGAddress } -func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { +func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { if i.SetFilterFunc == nil { return fmt.Errorf("not implemented") } @@ -35,7 +36,7 @@ func (i *IFaceMock) Address() iface.WGAddress { func TestManagerCreate(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -52,7 +53,7 @@ func TestManagerCreate(t *testing.T) { func TestManagerAddPeerFiltering(t *testing.T) { isSetFilterCalled := false ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { + SetFilterFunc: func(device.PacketFilter) error { isSetFilterCalled = true return nil }, @@ -90,7 +91,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { func TestManagerDeleteRule(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -236,7 +237,7 @@ func TestAddUDPPacketHook(t *testing.T) 
{ func TestManagerReset(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -271,7 +272,7 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -339,7 +340,7 @@ func TestNotMatchByIP(t *testing.T) { func TestRemovePacketHook(t *testing.T) { // creating mock iface iface := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } // creating manager instance @@ -388,7 +389,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } manager, err := Create(ifaceMock) require.NoError(t, err) diff --git a/iface/bind/bind.go b/client/iface/bind/bind.go similarity index 100% rename from iface/bind/bind.go rename to client/iface/bind/bind.go diff --git a/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go similarity index 100% rename from iface/bind/udp_mux.go rename to client/iface/bind/udp_mux.go diff --git a/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go similarity index 100% rename from iface/bind/udp_mux_universal.go rename to client/iface/bind/udp_mux_universal.go diff --git a/iface/bind/udp_muxed_conn.go b/client/iface/bind/udp_muxed_conn.go similarity index 100% rename from iface/bind/udp_muxed_conn.go rename to client/iface/bind/udp_muxed_conn.go diff --git a/client/iface/configurer/err.go b/client/iface/configurer/err.go new file mode 100644 
index 000000000..a64bba2dd --- /dev/null +++ b/client/iface/configurer/err.go @@ -0,0 +1,5 @@ +package configurer + +import "errors" + +var ErrPeerNotFound = errors.New("peer not found") diff --git a/iface/wg_configurer_kernel_unix.go b/client/iface/configurer/kernel_unix.go similarity index 83% rename from iface/wg_configurer_kernel_unix.go rename to client/iface/configurer/kernel_unix.go index 8b47082da..7c1c41669 100644 --- a/iface/wg_configurer_kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package configurer import ( "fmt" @@ -12,18 +12,17 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type wgKernelConfigurer struct { +type KernelConfigurer struct { deviceName string } -func newWGConfigurer(deviceName string) wgConfigurer { - wgc := &wgKernelConfigurer{ +func NewKernelConfigurer(deviceName string) *KernelConfigurer { + return &KernelConfigurer{ deviceName: deviceName, } - return wgc } -func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error { +func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -44,7 +43,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err return nil } -func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -75,7 +74,7 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA return nil } -func (c *wgKernelConfigurer) removePeer(peerKey string) error { +func (c *KernelConfigurer) 
RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -96,7 +95,7 @@ func (c *wgKernelConfigurer) removePeer(peerKey string) error { return nil } -func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error { +func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -123,7 +122,7 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro return nil } -func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { +func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return fmt.Errorf("parse allowed IP: %w", err) @@ -165,7 +164,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e return nil } -func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { +func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { wg, err := wgctrl.New() if err != nil { return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err) @@ -189,7 +188,7 @@ func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer return wgtypes.Peer{}, ErrPeerNotFound } -func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { +func (c *KernelConfigurer) configure(config wgtypes.Config) error { wg, err := wgctrl.New() if err != nil { return err @@ -205,10 +204,10 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { return wg.ConfigureDevice(c.deviceName, config) } -func (c *wgKernelConfigurer) close() { +func (c *KernelConfigurer) Close() { } -func (c *wgKernelConfigurer) getStats(peerKey string) (WGStats, error) { +func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) { peer, err := c.getPeer(c.deviceName, peerKey) if err != nil { return 
WGStats{}, fmt.Errorf("get wireguard stats: %w", err) diff --git a/iface/name.go b/client/iface/configurer/name.go similarity index 87% rename from iface/name.go rename to client/iface/configurer/name.go index 706cb65ad..e2133d0ea 100644 --- a/iface/name.go +++ b/client/iface/configurer/name.go @@ -1,6 +1,6 @@ //go:build linux || windows || freebsd -package iface +package configurer // WgInterfaceDefault is a default interface name of Wiretrustee const WgInterfaceDefault = "wt0" diff --git a/iface/name_darwin.go b/client/iface/configurer/name_darwin.go similarity index 86% rename from iface/name_darwin.go rename to client/iface/configurer/name_darwin.go index a4016ce15..034ce388d 100644 --- a/iface/name_darwin.go +++ b/client/iface/configurer/name_darwin.go @@ -1,6 +1,6 @@ //go:build darwin -package iface +package configurer // WgInterfaceDefault is a default interface name of Wiretrustee const WgInterfaceDefault = "utun100" diff --git a/iface/uapi.go b/client/iface/configurer/uapi.go similarity index 96% rename from iface/uapi.go rename to client/iface/configurer/uapi.go index d7ff52e7b..4801841de 100644 --- a/iface/uapi.go +++ b/client/iface/configurer/uapi.go @@ -1,6 +1,6 @@ //go:build !windows -package iface +package configurer import ( "net" diff --git a/iface/uapi_windows.go b/client/iface/configurer/uapi_windows.go similarity index 88% rename from iface/uapi_windows.go rename to client/iface/configurer/uapi_windows.go index e1f466364..46fa90c2e 100644 --- a/iface/uapi_windows.go +++ b/client/iface/configurer/uapi_windows.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "net" diff --git a/iface/wg_configurer_usp.go b/client/iface/configurer/usp.go similarity index 93% rename from iface/wg_configurer_usp.go rename to client/iface/configurer/usp.go index cd1d9d0b6..21d65ab2a 100644 --- a/iface/wg_configurer_usp.go +++ b/client/iface/configurer/usp.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "encoding/hex" @@ -19,15 +19,15 @@ 
import ( var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") -type wgUSPConfigurer struct { +type WGUSPConfigurer struct { device *device.Device deviceName string uapiListener net.Listener } -func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { - wgCfg := &wgUSPConfigurer{ +func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer { + wgCfg := &WGUSPConfigurer{ device: device, deviceName: deviceName, } @@ -35,7 +35,7 @@ func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { return wgCfg } -func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error { +func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -52,7 +52,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -80,7 +80,7 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) removePeer(peerKey string) error { +func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -97,7 +97,7 @@ func (c *wgUSPConfigurer) removePeer(peerKey string) error { return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { +func (c *WGUSPConfigurer) 
AddAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -121,7 +121,7 @@ func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { +func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error { ipc, err := c.device.IpcGet() if err != nil { return err @@ -185,7 +185,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { } // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool -func (t *wgUSPConfigurer) startUAPI() { +func (t *WGUSPConfigurer) startUAPI() { var err error t.uapiListener, err = openUAPI(t.deviceName) if err != nil { @@ -207,7 +207,7 @@ func (t *wgUSPConfigurer) startUAPI() { }(t.uapiListener) } -func (t *wgUSPConfigurer) close() { +func (t *WGUSPConfigurer) Close() { if t.uapiListener != nil { err := t.uapiListener.Close() if err != nil { @@ -223,7 +223,7 @@ func (t *wgUSPConfigurer) close() { } } -func (t *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) { +func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) { ipc, err := t.device.IpcGet() if err != nil { return WGStats{}, fmt.Errorf("ipc get: %w", err) diff --git a/iface/wg_configurer_usp_test.go b/client/iface/configurer/usp_test.go similarity index 99% rename from iface/wg_configurer_usp_test.go rename to client/iface/configurer/usp_test.go index ac0fc6130..775339f24 100644 --- a/iface/wg_configurer_usp_test.go +++ b/client/iface/configurer/usp_test.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "encoding/hex" diff --git a/client/iface/configurer/wgstats.go b/client/iface/configurer/wgstats.go new file mode 100644 index 000000000..56d0d7310 --- /dev/null +++ b/client/iface/configurer/wgstats.go @@ -0,0 +1,9 @@ +package configurer + +import "time" + +type 
WGStats struct { + LastHandshake time.Time + TxBytes int64 + RxBytes int64 +} diff --git a/client/iface/device.go b/client/iface/device.go new file mode 100644 index 000000000..0d4e69145 --- /dev/null +++ b/client/iface/device.go @@ -0,0 +1,18 @@ +//go:build !android + +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" +) + +type WGTunDevice interface { + Create() (device.WGConfigurer, error) + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(address WGAddress) error + WgAddress() WGAddress + DeviceName() string + Close() error + FilteredDevice() *device.FilteredDevice +} diff --git a/iface/tun_adapter.go b/client/iface/device/adapter.go similarity index 94% rename from iface/tun_adapter.go rename to client/iface/device/adapter.go index adec93ed1..6ebc05390 100644 --- a/iface/tun_adapter.go +++ b/client/iface/device/adapter.go @@ -1,4 +1,4 @@ -package iface +package device // TunAdapter is an interface for create tun device from external service type TunAdapter interface { diff --git a/iface/address.go b/client/iface/device/address.go similarity index 69% rename from iface/address.go rename to client/iface/device/address.go index 5ff4fbc06..15de301da 100644 --- a/iface/address.go +++ b/client/iface/device/address.go @@ -1,18 +1,18 @@ -package iface +package device import ( "fmt" "net" ) -// WGAddress Wireguard parsed address +// WGAddress WireGuard parsed address type WGAddress struct { IP net.IP Network *net.IPNet } -// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address -func parseWGAddress(address string) (WGAddress, error) { +// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address +func ParseWGAddress(address string) (WGAddress, error) { ip, network, err := net.ParseCIDR(address) if err != nil { return WGAddress{}, err diff --git a/iface/tun_args.go b/client/iface/device/args.go similarity index 88% rename from iface/tun_args.go rename to 
client/iface/device/args.go index 0eac2c4c0..d7b86b335 100644 --- a/iface/tun_args.go +++ b/client/iface/device/args.go @@ -1,4 +1,4 @@ -package iface +package device type MobileIFaceArguments struct { TunAdapter TunAdapter // only for Android diff --git a/iface/tun_android.go b/client/iface/device/device_android.go similarity index 61% rename from iface/tun_android.go rename to client/iface/device/device_android.go index 504993094..29e3f409d 100644 --- a/iface/tun_android.go +++ b/client/iface/device/device_android.go @@ -1,7 +1,6 @@ //go:build android -// +build android -package iface +package device import ( "strings" @@ -12,11 +11,12 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -// ignore the wgTunDevice interface on Android because the creation of the tun device is different on this platform -type wgTunDevice struct { +// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform +type WGTunDevice struct { address WGAddress port int key string @@ -24,15 +24,15 @@ type wgTunDevice struct { iceBind *bind.ICEBind tunAdapter TunAdapter - name string - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + name string + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice { - return wgTunDevice{ +func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice { + return &WGTunDevice{ address: address, port: port, key: key, @@ -42,7 +42,7 @@ func 
newTunDevice(address WGAddress, port int, key string, mtu int, transportNet } } -func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string) (wgConfigurer, error) { +func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { log.Info("create tun interface") routesString := routesToString(routes) @@ -61,24 +61,24 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string return nil, err } t.name = name - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) log.Debugf("attaching to interface %v", name) - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) + t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, err } return t.configurer, nil } -func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -93,14 +93,14 @@ func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *wgTunDevice) UpdateAddr(addr WGAddress) error { +func (t *WGTunDevice) UpdateAddr(addr WGAddress) error { // todo implement return nil } -func (t *wgTunDevice) Close() error { +func (t *WGTunDevice) Close() error { if t.configurer != nil { - 
t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -115,20 +115,20 @@ func (t *wgTunDevice) Close() error { return nil } -func (t *wgTunDevice) Device() *device.Device { +func (t *WGTunDevice) Device() *device.Device { return t.device } -func (t *wgTunDevice) DeviceName() string { +func (t *WGTunDevice) DeviceName() string { return t.name } -func (t *wgTunDevice) WgAddress() WGAddress { +func (t *WGTunDevice) WgAddress() WGAddress { return t.address } -func (t *wgTunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *WGTunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } func routesToString(routes []string) string { diff --git a/iface/tun_darwin.go b/client/iface/device/device_darwin.go similarity index 69% rename from iface/tun_darwin.go rename to client/iface/device/device_darwin.go index fcf9f8ba0..03e85a7f1 100644 --- a/iface/tun_darwin.go +++ b/client/iface/device/device_darwin.go @@ -1,6 +1,6 @@ //go:build !ios -package iface +package device import ( "fmt" @@ -11,10 +11,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -22,14 +23,14 @@ type tunDevice struct { mtu int iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { + 
return &TunDevice{ name: name, address: address, port: port, @@ -39,16 +40,16 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int, } } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { tunDevice, err := tun.CreateTUN(t.name, t.mtu) if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -59,17 +60,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -84,14 +85,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -105,20 +106,20 @@ func (t *tunDevice) Close() error { return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { 
return t.address } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (t *tunDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) if out, err := cmd.CombinedOutput(); err != nil { log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out) diff --git a/iface/device_wrapper.go b/client/iface/device/device_filter.go similarity index 81% rename from iface/device_wrapper.go rename to client/iface/device/device_filter.go index 2fa219395..f87f10429 100644 --- a/iface/device_wrapper.go +++ b/client/iface/device/device_filter.go @@ -1,4 +1,4 @@ -package iface +package device import ( "net" @@ -28,22 +28,23 @@ type PacketFilter interface { SetNetwork(*net.IPNet) } -// DeviceWrapper to override Read or Write of packets -type DeviceWrapper struct { +// FilteredDevice to override Read or Write of packets +type FilteredDevice struct { tun.Device + filter PacketFilter mutex sync.RWMutex } -// newDeviceWrapper constructor function -func newDeviceWrapper(device tun.Device) *DeviceWrapper { - return &DeviceWrapper{ +// newDeviceFilter constructor function +func newDeviceFilter(device tun.Device) *FilteredDevice { + return &FilteredDevice{ Device: device, } } // Read wraps read method with filtering feature -func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { +func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { if n, err = d.Device.Read(bufs, sizes, offset); err != nil { return 0, err } @@ -68,7 +69,7 @@ func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) 
(n int, err } // Write wraps write method with filtering feature -func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { +func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { d.mutex.RLock() filter := d.filter d.mutex.RUnlock() @@ -92,7 +93,7 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { } // SetFilter sets packet filter to device -func (d *DeviceWrapper) SetFilter(filter PacketFilter) { +func (d *FilteredDevice) SetFilter(filter PacketFilter) { d.mutex.Lock() d.filter = filter d.mutex.Unlock() diff --git a/iface/device_wrapper_test.go b/client/iface/device/device_filter_test.go similarity index 95% rename from iface/device_wrapper_test.go rename to client/iface/device/device_filter_test.go index 2d3725ea4..d3278b918 100644 --- a/iface/device_wrapper_test.go +++ b/client/iface/device/device_filter_test.go @@ -1,4 +1,4 @@ -package iface +package device import ( "net" @@ -7,7 +7,8 @@ import ( "github.com/golang/mock/gomock" "github.com/google/gopacket" "github.com/google/gopacket/layers" - mocks "github.com/netbirdio/netbird/iface/mocks" + + mocks "github.com/netbirdio/netbird/client/iface/mocks" ) func TestDeviceWrapperRead(t *testing.T) { @@ -51,7 +52,7 @@ func TestDeviceWrapperRead(t *testing.T) { return 1, nil }) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) bufs := [][]byte{{}} sizes := []int{0} @@ -99,7 +100,7 @@ func TestDeviceWrapperRead(t *testing.T) { tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Write(mockBufs, 0).Return(1, nil) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) bufs := [][]byte{buffer.Bytes()} @@ -147,7 +148,7 @@ func TestDeviceWrapperRead(t *testing.T) { filter := mocks.NewMockPacketFilter(ctrl) filter.EXPECT().DropIncoming(gomock.Any()).Return(true) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) wrapped.filter = filter bufs := [][]byte{buffer.Bytes()} @@ -202,7 +203,7 @@ func TestDeviceWrapperRead(t 
*testing.T) { filter := mocks.NewMockPacketFilter(ctrl) filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) wrapped.filter = filter bufs := [][]byte{{}} diff --git a/iface/tun_ios.go b/client/iface/device/device_ios.go similarity index 63% rename from iface/tun_ios.go rename to client/iface/device/device_ios.go index 6d53cc333..226e8a2e0 100644 --- a/iface/tun_ios.go +++ b/client/iface/device/device_ios.go @@ -1,7 +1,7 @@ //go:build ios // +build ios -package iface +package device import ( "os" @@ -12,10 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -23,14 +24,14 @@ type tunDevice struct { iceBind *bind.ICEBind tunFd int - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -40,7 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, transpor } } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { log.Infof("create tun interface") dupTunFd, err := unix.Dup(t.tunFd) @@ -62,24 +63,24 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, err } - t.wrapper = 
newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) log.Debug("Attaching to interface") - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) + t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, err } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -94,17 +95,17 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) Device() *device.Device { +func (t *TunDevice) Device() *device.Device { return t.device } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -119,15 +120,15 @@ func (t *tunDevice) Close() error { return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) UpdateAddr(addr WGAddress) error { +func (t *TunDevice) UpdateAddr(addr WGAddress) error { // todo implement return nil } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t 
*TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } diff --git a/iface/tun_kernel_unix.go b/client/iface/device/device_kernel_unix.go similarity index 75% rename from iface/tun_kernel_unix.go rename to client/iface/device/device_kernel_unix.go index 220c07888..f355d2cf7 100644 --- a/iface/tun_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package device import ( "context" @@ -10,11 +10,12 @@ import ( "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/sharedsock" ) -type tunKernelDevice struct { +type TunKernelDevice struct { name string address WGAddress wgPort int @@ -31,11 +32,11 @@ type tunKernelDevice struct { filterFn bind.FilterFn } -func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { +func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { checkUser() ctx, cancel := context.WithCancel(context.Background()) - return &tunKernelDevice{ + return &TunKernelDevice{ ctx: ctx, ctxCancel: cancel, name: name, @@ -47,7 +48,7 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in } } -func (t *tunKernelDevice) Create() (wgConfigurer, error) { +func (t *TunKernelDevice) Create() (WGConfigurer, error) { link := newWGLink(t.name) if err := link.recreate(); err != nil { @@ -67,16 +68,16 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("set mtu: %w", err) } - configurer := newWGConfigurer(t.name) + configurer := configurer.NewKernelConfigurer(t.name) - if err := configurer.configureInterface(t.key, t.wgPort); err != nil { + if err := 
configurer.ConfigureInterface(t.key, t.wgPort); err != nil { return nil, fmt.Errorf("error configuring interface: %s", err) } return configurer, nil } -func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -111,12 +112,12 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return t.udpMux, nil } -func (t *tunKernelDevice) UpdateAddr(address WGAddress) error { +func (t *TunKernelDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunKernelDevice) Close() error { +func (t *TunKernelDevice) Close() error { if t.link == nil { return nil } @@ -144,19 +145,19 @@ func (t *tunKernelDevice) Close() error { return closErr } -func (t *tunKernelDevice) WgAddress() WGAddress { +func (t *TunKernelDevice) WgAddress() WGAddress { return t.address } -func (t *tunKernelDevice) DeviceName() string { +func (t *TunKernelDevice) DeviceName() string { return t.name } -func (t *tunKernelDevice) Wrapper() *DeviceWrapper { +func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { return nil } // assignAddr Adds IP address to the tunnel interface -func (t *tunKernelDevice) assignAddr() error { +func (t *TunKernelDevice) assignAddr() error { return t.link.assignAddr(t.address) } diff --git a/iface/tun_netstack.go b/client/iface/device/device_netstack.go similarity index 56% rename from iface/tun_netstack.go rename to client/iface/device/device_netstack.go index de1ff6654..440a1ca19 100644 --- a/iface/tun_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,7 +1,7 @@ //go:build !android // +build !android -package iface +package device import ( "fmt" @@ -10,11 +10,12 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + 
"github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/netstack" ) -type tunNetstackDevice struct { +type TunNetstackDevice struct { name string address WGAddress port int @@ -23,15 +24,15 @@ type tunNetstackDevice struct { listenAddress string iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - nsTun *netstack.NetStackTun - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + nsTun *netstack.NetStackTun + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice { - return &tunNetstackDevice{ +func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice { + return &TunNetstackDevice{ name: name, address: address, port: wgPort, @@ -42,23 +43,23 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string } } -func (t *tunNetstackDevice) Create() (wgConfigurer, error) { +func (t *TunNetstackDevice) Create() (WGConfigurer, error) { log.Info("create netstack tun interface") t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu) tunIface, err := t.nsTun.Create() if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunIface) + t.filteredDevice = newDeviceFilter(tunIface) t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = 
t.configurer.ConfigureInterface(t.key, t.port) if err != nil { _ = tunIface.Close() return nil, fmt.Errorf("error configuring interface: %s", err) @@ -68,7 +69,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) { return t.configurer, nil } -func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -87,13 +88,13 @@ func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunNetstackDevice) UpdateAddr(WGAddress) error { +func (t *TunNetstackDevice) UpdateAddr(WGAddress) error { return nil } -func (t *tunNetstackDevice) Close() error { +func (t *TunNetstackDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -106,14 +107,14 @@ func (t *tunNetstackDevice) Close() error { return nil } -func (t *tunNetstackDevice) WgAddress() WGAddress { +func (t *TunNetstackDevice) WgAddress() WGAddress { return t.address } -func (t *tunNetstackDevice) DeviceName() string { +func (t *TunNetstackDevice) DeviceName() string { return t.name } -func (t *tunNetstackDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } diff --git a/iface/tun_usp_unix.go b/client/iface/device/device_usp_unix.go similarity index 63% rename from iface/tun_usp_unix.go rename to client/iface/device/device_usp_unix.go index 1c1d3ac89..4175f6556 100644 --- a/iface/tun_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package device import ( "fmt" @@ -12,10 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + 
"github.com/netbirdio/netbird/client/iface/configurer" ) -type tunUSPDevice struct { +type USPDevice struct { name string address WGAddress port int @@ -23,39 +24,38 @@ type tunUSPDevice struct { mtu int iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { +func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice { log.Infof("using userspace bind mode") checkUser() - return &tunUSPDevice{ + return &USPDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), - } + iceBind: bind.NewICEBind(transportNet, filterFn)} } -func (t *tunUSPDevice) Create() (wgConfigurer, error) { +func (t *USPDevice) Create() (WGConfigurer, error) { log.Info("create tun interface") tunIface, err := tun.CreateTUN(t.name, t.mtu) if err != nil { log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err) return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunIface) + t.filteredDevice = newDeviceFilter(tunIface) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -66,17 +66,17 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = 
t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -96,14 +96,14 @@ func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunUSPDevice) UpdateAddr(address WGAddress) error { +func (t *USPDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunUSPDevice) Close() error { +func (t *USPDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -116,20 +116,20 @@ func (t *tunUSPDevice) Close() error { return nil } -func (t *tunUSPDevice) WgAddress() WGAddress { +func (t *USPDevice) WgAddress() WGAddress { return t.address } -func (t *tunUSPDevice) DeviceName() string { +func (t *USPDevice) DeviceName() string { return t.name } -func (t *tunUSPDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *USPDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } // assignAddr Adds IP address to the tunnel interface -func (t *tunUSPDevice) assignAddr() error { +func (t *USPDevice) assignAddr() error { link := newWGLink(t.name) return link.assignAddr(t.address) diff --git a/iface/tun_windows.go b/client/iface/device/device_windows.go similarity index 75% rename from iface/tun_windows.go rename to client/iface/device/device_windows.go index afb67bcc0..f3e216ccd 100644 --- a/iface/tun_windows.go +++ b/client/iface/device/device_windows.go @@ -1,4 +1,4 @@ -package iface +package device import ( "fmt" @@ -11,12 +11,13 @@ import ( "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - 
"github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -26,13 +27,13 @@ type tunDevice struct { device *device.Device nativeTunDevice *tun.NativeTun - wrapper *DeviceWrapper + filteredDevice *FilteredDevice udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -50,7 +51,7 @@ func getGUID() (windows.GUID, error) { return windows.GUIDFromString(guidString) } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { guid, err := getGUID() if err != nil { log.Errorf("failed to get GUID: %s", err) @@ -62,11 +63,11 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error creating tun device: %s", err) } t.nativeTunDevice = tunDevice.(*tun.NativeTun) - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -92,17 +93,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, 
t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -117,14 +118,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -138,19 +139,19 @@ func (t *tunDevice) Close() error { } return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } -func (t *tunDevice) getInterfaceGUIDString() (string, error) { +func (t *TunDevice) GetInterfaceGUIDString() (string, error) { if t.nativeTunDevice == nil { return "", fmt.Errorf("interface has not been initialized yet") } @@ -164,7 +165,7 @@ func (t *tunDevice) getInterfaceGUIDString() (string, error) { } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (t *tunDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { luid := winipcfg.LUID(t.nativeTunDevice.LUID()) log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) diff 
--git a/client/iface/device/interface.go b/client/iface/device/interface.go new file mode 100644 index 000000000..0196b0085 --- /dev/null +++ b/client/iface/device/interface.go @@ -0,0 +1,20 @@ +package device + +import ( + "net" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +type WGConfigurer interface { + ConfigureInterface(privateKey string, port int) error + UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeer(peerKey string) error + AddAllowedIP(peerKey string, allowedIP string) error + RemoveAllowedIP(peerKey string, allowedIP string) error + Close() + GetStats(peerKey string) (configurer.WGStats, error) +} diff --git a/iface/module.go b/client/iface/device/kernel_module.go similarity index 92% rename from iface/module.go rename to client/iface/device/kernel_module.go index ca70cf3c7..1bdd6f7c6 100644 --- a/iface/module.go +++ b/client/iface/device/kernel_module.go @@ -1,6 +1,6 @@ //go:build (!linux && !freebsd) || android -package iface +package device // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only) func WireGuardModuleIsLoaded() bool { diff --git a/iface/module_freebsd.go b/client/iface/device/kernel_module_freebsd.go similarity index 84% rename from iface/module_freebsd.go rename to client/iface/device/kernel_module_freebsd.go index 00ad882c2..dd6c8b408 100644 --- a/iface/module_freebsd.go +++ b/client/iface/device/kernel_module_freebsd.go @@ -1,4 +1,4 @@ -package iface +package device // WireGuardModuleIsLoaded check if kernel support wireguard func WireGuardModuleIsLoaded() bool { @@ -10,8 +10,8 @@ func WireGuardModuleIsLoaded() bool { return false } -// tunModuleIsLoaded check if tun module exist, if is not attempt to load it -func tunModuleIsLoaded() bool { +// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it +func ModuleTunIsLoaded() bool { 
// Assume tun supported by freebsd kernel by default // TODO: implement check for module loaded in kernel or build-it return true diff --git a/iface/module_linux.go b/client/iface/device/kernel_module_linux.go similarity index 98% rename from iface/module_linux.go rename to client/iface/device/kernel_module_linux.go index 11c0482d5..0d195779d 100644 --- a/iface/module_linux.go +++ b/client/iface/device/kernel_module_linux.go @@ -1,7 +1,7 @@ //go:build linux && !android // Package iface provides wireguard network interface creation and management -package iface +package device import ( "bufio" @@ -66,8 +66,8 @@ func getModuleRoot() string { return filepath.Join(moduleLibDir, string(uname.Release[:i])) } -// tunModuleIsLoaded check if tun module exist, if is not attempt to load it -func tunModuleIsLoaded() bool { +// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it +func ModuleTunIsLoaded() bool { _, err := os.Stat("/dev/net/tun") if err == nil { return true diff --git a/iface/module_linux_test.go b/client/iface/device/kernel_module_linux_test.go similarity index 98% rename from iface/module_linux_test.go rename to client/iface/device/kernel_module_linux_test.go index 97e9b1f78..de9656e47 100644 --- a/iface/module_linux_test.go +++ b/client/iface/device/kernel_module_linux_test.go @@ -1,4 +1,6 @@ -package iface +//go:build linux && !android + +package device import ( "bufio" @@ -132,7 +134,7 @@ func resetGlobals() { } func createFiles(t *testing.T) (string, []module) { - t.Helper() + t.Helper() writeFile := func(path, text string) { if err := os.WriteFile(path, []byte(text), 0644); err != nil { t.Fatal(err) @@ -168,7 +170,7 @@ func createFiles(t *testing.T) (string, []module) { } func getRandomLoadedModule(t *testing.T) (string, error) { - t.Helper() + t.Helper() f, err := os.Open("/proc/modules") if err != nil { return "", err diff --git a/iface/tun_link_freebsd.go b/client/iface/device/wg_link_freebsd.go similarity index 95% rename from 
iface/tun_link_freebsd.go rename to client/iface/device/wg_link_freebsd.go index be7921fdb..104010f47 100644 --- a/iface/tun_link_freebsd.go +++ b/client/iface/device/wg_link_freebsd.go @@ -1,10 +1,11 @@ -package iface +package device import ( "fmt" - "github.com/netbirdio/netbird/iface/freebsd" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/freebsd" ) type wgLink struct { diff --git a/iface/tun_link_linux.go b/client/iface/device/wg_link_linux.go similarity index 99% rename from iface/tun_link_linux.go rename to client/iface/device/wg_link_linux.go index 3ce644e84..a15cffe48 100644 --- a/iface/tun_link_linux.go +++ b/client/iface/device/wg_link_linux.go @@ -1,6 +1,6 @@ //go:build linux && !android -package iface +package device import ( "fmt" diff --git a/iface/wg_log.go b/client/iface/device/wg_log.go similarity index 93% rename from iface/wg_log.go rename to client/iface/device/wg_log.go index b44f6fc0b..db2f3111f 100644 --- a/iface/wg_log.go +++ b/client/iface/device/wg_log.go @@ -1,4 +1,4 @@ -package iface +package device import ( "os" diff --git a/client/iface/device/windows_guid.go b/client/iface/device/windows_guid.go new file mode 100644 index 000000000..1c7d40d13 --- /dev/null +++ b/client/iface/device/windows_guid.go @@ -0,0 +1,4 @@ +package device + +// CustomWindowsGUIDString is a custom GUID string for the interface +var CustomWindowsGUIDString string diff --git a/client/iface/device_android.go b/client/iface/device_android.go new file mode 100644 index 000000000..3d15080ff --- /dev/null +++ b/client/iface/device_android.go @@ -0,0 +1,16 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" +) + +type WGTunDevice interface { + Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(address WGAddress) error + WgAddress() WGAddress + DeviceName() string 
+ Close() error + FilteredDevice() *device.FilteredDevice +} diff --git a/iface/freebsd/errors.go b/client/iface/freebsd/errors.go similarity index 100% rename from iface/freebsd/errors.go rename to client/iface/freebsd/errors.go diff --git a/iface/freebsd/iface.go b/client/iface/freebsd/iface.go similarity index 100% rename from iface/freebsd/iface.go rename to client/iface/freebsd/iface.go diff --git a/iface/freebsd/iface_internal_test.go b/client/iface/freebsd/iface_internal_test.go similarity index 100% rename from iface/freebsd/iface_internal_test.go rename to client/iface/freebsd/iface_internal_test.go diff --git a/iface/freebsd/link.go b/client/iface/freebsd/link.go similarity index 100% rename from iface/freebsd/link.go rename to client/iface/freebsd/link.go diff --git a/iface/iface.go b/client/iface/iface.go similarity index 79% rename from iface/iface.go rename to client/iface/iface.go index 545feffcf..accf5ce0a 100644 --- a/iface/iface.go +++ b/client/iface/iface.go @@ -9,28 +9,27 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) const ( - DefaultMTU = 1280 - DefaultWgPort = 51820 + DefaultMTU = 1280 + DefaultWgPort = 51820 + WgInterfaceDefault = configurer.WgInterfaceDefault ) -// WGIface represents a interface instance +type WGAddress = device.WGAddress + +// WGIface represents an interface instance type WGIface struct { - tun wgTunDevice + tun WGTunDevice userspaceBind bool mu sync.Mutex - configurer wgConfigurer - filter PacketFilter -} - -type WGStats struct { - LastHandshake time.Time - TxBytes int64 - RxBytes int64 + configurer device.WGConfigurer + filter device.PacketFilter } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -44,7 +43,7 @@ func (w *WGIface) Name() 
string { } // Address returns the interface address -func (w *WGIface) Address() WGAddress { +func (w *WGIface) Address() device.WGAddress { return w.tun.WgAddress() } @@ -75,7 +74,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error { w.mu.Lock() defer w.mu.Unlock() - addr, err := parseWGAddress(newAddr) + addr, err := device.ParseWGAddress(newAddr) if err != nil { return err } @@ -90,7 +89,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D defer w.mu.Unlock() log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) - return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) + return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } // RemovePeer removes a Wireguard Peer from the interface iface @@ -99,7 +98,7 @@ func (w *WGIface) RemovePeer(peerKey string) error { defer w.mu.Unlock() log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName()) - return w.configurer.removePeer(peerKey) + return w.configurer.RemovePeer(peerKey) } // AddAllowedIP adds a prefix to the allowed IPs list of peer @@ -108,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { defer w.mu.Unlock() log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) - return w.configurer.addAllowedIP(peerKey, allowedIP) + return w.configurer.AddAllowedIP(peerKey, allowedIP) } // RemoveAllowedIP removes a prefix from the allowed IPs list of peer @@ -117,7 +116,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { defer w.mu.Unlock() log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) - return w.configurer.removeAllowedIP(peerKey, allowedIP) + return w.configurer.RemoveAllowedIP(peerKey, allowedIP) } // Close closes the tunnel interface @@ -144,23 +143,23 @@ func (w 
*WGIface) Close() error { } // SetFilter sets packet filters for the userspace implementation -func (w *WGIface) SetFilter(filter PacketFilter) error { +func (w *WGIface) SetFilter(filter device.PacketFilter) error { w.mu.Lock() defer w.mu.Unlock() - if w.tun.Wrapper() == nil { + if w.tun.FilteredDevice() == nil { return fmt.Errorf("userspace packet filtering not handled on this device") } w.filter = filter w.filter.SetNetwork(w.tun.WgAddress().Network) - w.tun.Wrapper().SetFilter(filter) + w.tun.FilteredDevice().SetFilter(filter) return nil } // GetFilter returns packet filter used by interface if it uses userspace device implementation -func (w *WGIface) GetFilter() PacketFilter { +func (w *WGIface) GetFilter() device.PacketFilter { w.mu.Lock() defer w.mu.Unlock() @@ -168,16 +167,16 @@ func (w *WGIface) GetFilter() PacketFilter { } // GetDevice to interact with raw device (with filtering) -func (w *WGIface) GetDevice() *DeviceWrapper { +func (w *WGIface) GetDevice() *device.FilteredDevice { w.mu.Lock() defer w.mu.Unlock() - return w.tun.Wrapper() + return w.tun.FilteredDevice() } // GetStats returns the last handshake time, rx and tx bytes for the given peer -func (w *WGIface) GetStats(peerKey string) (WGStats, error) { - return w.configurer.getStats(peerKey) +func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { + return w.configurer.GetStats(peerKey) } func (w *WGIface) waitUntilRemoved() error { diff --git a/iface/iface_android.go b/client/iface/iface_android.go similarity index 67% rename from iface/iface_android.go rename to client/iface/iface_android.go index 99f6885a5..5ed476e70 100644 --- a/iface/iface_android.go +++ b/client/iface/iface_android.go @@ -5,18 +5,19 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName 
string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), + tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_create.go b/client/iface/iface_create.go similarity index 100% rename from iface/iface_create.go rename to client/iface/iface_create.go diff --git a/iface/iface_darwin.go b/client/iface/iface_darwin.go similarity index 68% rename from iface/iface_darwin.go rename to client/iface/iface_darwin.go index f48f324c3..b46ea0f80 100644 --- a/iface/iface_darwin.go +++ b/client/iface/iface_darwin.go @@ -9,13 +9,14 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + 
wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -25,11 +26,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) + wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } diff --git a/iface/iface_destroy_bsd.go b/client/iface/iface_destroy_bsd.go similarity index 100% rename from iface/iface_destroy_bsd.go rename to client/iface/iface_destroy_bsd.go diff --git a/iface/iface_destroy_linux.go b/client/iface/iface_destroy_linux.go similarity index 100% rename from iface/iface_destroy_linux.go rename to client/iface/iface_destroy_linux.go diff --git a/iface/iface_destroy_mobile.go b/client/iface/iface_destroy_mobile.go similarity index 100% rename from iface/iface_destroy_mobile.go rename to client/iface/iface_destroy_mobile.go diff --git a/iface/iface_destroy_windows.go b/client/iface/iface_destroy_windows.go similarity index 100% rename from iface/iface_destroy_windows.go rename to client/iface/iface_destroy_windows.go diff --git a/iface/iface_ios.go b/client/iface/iface_ios.go similarity index 59% rename from iface/iface_ios.go rename to client/iface/iface_ios.go index 6babe5964..fc0214748 100644 --- a/iface/iface_ios.go +++ b/client/iface/iface_ios.go @@ -7,17 +7,18 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" ) // NewWGIFace Creates a new WireGuard interface instance -func 
NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), + tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_moc.go b/client/iface/iface_moc.go similarity index 76% rename from iface/iface_moc.go rename to client/iface/iface_moc.go index fab3054a0..703da9ce0 100644 --- a/iface/iface_moc.go +++ b/client/iface/iface_moc.go @@ -6,7 +6,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) type MockWGIface struct { @@ -14,7 +16,7 @@ type MockWGIface struct { CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error IsUserspaceBindFunc func() bool NameFunc func() string - AddressFunc func() WGAddress + AddressFunc func() device.WGAddress ToInterfaceFunc func() *net.Interface UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error @@ -23,10 +25,10 @@ type MockWGIface struct { AddAllowedIPFunc func(peerKey string, allowedIP string) error RemoveAllowedIPFunc func(peerKey string, allowedIP string) error CloseFunc func() error - SetFilterFunc func(filter PacketFilter) error - GetFilterFunc func() PacketFilter - GetDeviceFunc 
func() *DeviceWrapper - GetStatsFunc func(peerKey string) (WGStats, error) + SetFilterFunc func(filter device.PacketFilter) error + GetFilterFunc func() device.PacketFilter + GetDeviceFunc func() *device.FilteredDevice + GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) } @@ -50,7 +52,7 @@ func (m *MockWGIface) Name() string { return m.NameFunc() } -func (m *MockWGIface) Address() WGAddress { +func (m *MockWGIface) Address() device.WGAddress { return m.AddressFunc() } @@ -86,18 +88,18 @@ func (m *MockWGIface) Close() error { return m.CloseFunc() } -func (m *MockWGIface) SetFilter(filter PacketFilter) error { +func (m *MockWGIface) SetFilter(filter device.PacketFilter) error { return m.SetFilterFunc(filter) } -func (m *MockWGIface) GetFilter() PacketFilter { +func (m *MockWGIface) GetFilter() device.PacketFilter { return m.GetFilterFunc() } -func (m *MockWGIface) GetDevice() *DeviceWrapper { +func (m *MockWGIface) GetDevice() *device.FilteredDevice { return m.GetDeviceFunc() } -func (m *MockWGIface) GetStats(peerKey string) (WGStats, error) { +func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { return m.GetStatsFunc(peerKey) } diff --git a/iface/iface_test.go b/client/iface/iface_test.go similarity index 98% rename from iface/iface_test.go rename to client/iface/iface_test.go index 8de9f647e..87a68addb 100644 --- a/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" ) // keep darwin compatibility @@ -414,7 +416,7 @@ func Test_ConnectPeers(t *testing.T) { } guid := fmt.Sprintf("{%s}", uuid.New().String()) - CustomWindowsGUIDString = strings.ToLower(guid) + device.CustomWindowsGUIDString = strings.ToLower(guid) iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, 
peer1Key.String(), DefaultMTU, newNet, nil, nil) if err != nil { @@ -436,7 +438,7 @@ func Test_ConnectPeers(t *testing.T) { } guid = fmt.Sprintf("{%s}", uuid.New().String()) - CustomWindowsGUIDString = strings.ToLower(guid) + device.CustomWindowsGUIDString = strings.ToLower(guid) newNet, err = stdnet.NewNet() if err != nil { diff --git a/iface/iface_unix.go b/client/iface/iface_unix.go similarity index 53% rename from iface/iface_unix.go rename to client/iface/iface_unix.go index 9608df1ad..09dbb2c1f 100644 --- a/iface/iface_unix.go +++ b/client/iface/iface_unix.go @@ -8,13 +8,14 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -23,21 +24,21 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, // move the kernel/usp/netstack preference evaluation to upper layer if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) wgIFace.userspaceBind = true return wgIFace, nil } - if WireGuardModuleIsLoaded() { 
- wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + if device.WireGuardModuleIsLoaded() { + wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) wgIFace.userspaceBind = false return wgIFace, nil } - if !tunModuleIsLoaded() { + if !device.ModuleTunIsLoaded() { return nil, fmt.Errorf("couldn't check or load tun module") } - wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) + wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) wgIFace.userspaceBind = true return wgIFace, nil } diff --git a/iface/iface_windows.go b/client/iface/iface_windows.go similarity index 52% rename from iface/iface_windows.go rename to client/iface/iface_windows.go index c5edd27a9..6845ef3dd 100644 --- a/iface/iface_windows.go +++ b/client/iface/iface_windows.go @@ -5,13 +5,14 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -21,11 +22,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, 
wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) + wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } @@ -36,5 +37,5 @@ func (w *WGIface) CreateOnAndroid([]string, string, []string) error { // GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only func (w *WGIface) GetInterfaceGUIDString() (string, error) { - return w.tun.(*tunDevice).getInterfaceGUIDString() + return w.tun.(*device.TunDevice).GetInterfaceGUIDString() } diff --git a/iface/iwginterface.go b/client/iface/iwginterface.go similarity index 65% rename from iface/iwginterface.go rename to client/iface/iwginterface.go index 501f51d2b..cb6d7ccd9 100644 --- a/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -8,7 +8,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) type IWGIface interface { @@ -16,7 +18,7 @@ type IWGIface interface { CreateOnAndroid(routeRange []string, ip string, domains []string) error IsUserspaceBind() bool Name() string - Address() WGAddress + Address() device.WGAddress ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error @@ -25,8 +27,8 @@ type IWGIface interface { AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error Close() error - SetFilter(filter PacketFilter) error - GetFilter() PacketFilter - GetDevice() *DeviceWrapper - GetStats(peerKey string) (WGStats, error) + SetFilter(filter 
device.PacketFilter) error + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go similarity index 65% rename from iface/iwginterface_windows.go rename to client/iface/iwginterface_windows.go index b5053474e..6baeb66ae 100644 --- a/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -6,7 +6,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) type IWGIface interface { @@ -14,7 +16,7 @@ type IWGIface interface { CreateOnAndroid(routeRange []string, ip string, domains []string) error IsUserspaceBind() bool Name() string - Address() WGAddress + Address() device.WGAddress ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error @@ -23,9 +25,9 @@ type IWGIface interface { AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error Close() error - SetFilter(filter PacketFilter) error - GetFilter() PacketFilter - GetDevice() *DeviceWrapper - GetStats(peerKey string) (WGStats, error) + SetFilter(filter device.PacketFilter) error + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/iface/mocks/README.md b/client/iface/mocks/README.md similarity index 100% rename from iface/mocks/README.md rename to client/iface/mocks/README.md diff --git a/iface/mocks/filter.go b/client/iface/mocks/filter.go similarity index 97% rename from iface/mocks/filter.go rename to client/iface/mocks/filter.go index 2d80d69f1..6348e0e77 100644 --- a/iface/mocks/filter.go +++ 
b/client/iface/mocks/filter.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) +// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) // Package mocks is a generated GoMock package. package mocks diff --git a/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go similarity index 97% rename from iface/mocks/iface/mocks/filter.go rename to client/iface/mocks/iface/mocks/filter.go index 059a2b9a0..17e123abb 100644 --- a/iface/mocks/iface/mocks/filter.go +++ b/client/iface/mocks/iface/mocks/filter.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) +// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) // Package mocks is a generated GoMock package. package mocks diff --git a/iface/mocks/tun.go b/client/iface/mocks/tun.go similarity index 100% rename from iface/mocks/tun.go rename to client/iface/mocks/tun.go diff --git a/iface/netstack/dialer.go b/client/iface/netstack/dialer.go similarity index 100% rename from iface/netstack/dialer.go rename to client/iface/netstack/dialer.go diff --git a/iface/netstack/env.go b/client/iface/netstack/env.go similarity index 100% rename from iface/netstack/env.go rename to client/iface/netstack/env.go diff --git a/iface/netstack/proxy.go b/client/iface/netstack/proxy.go similarity index 100% rename from iface/netstack/proxy.go rename to client/iface/netstack/proxy.go diff --git a/iface/netstack/tun.go b/client/iface/netstack/tun.go similarity index 100% rename from iface/netstack/tun.go rename to client/iface/netstack/tun.go diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index eec3d3b8c..7d999669a 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,8 +9,8 @@ import ( 
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/acl/mocks" - "github.com/netbirdio/netbird/iface" mgmProto "github.com/netbirdio/netbird/management/proto" ) diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go index 621b29513..3ed12b6dd 100644 --- a/client/internal/acl/mocks/iface_mapper.go +++ b/client/internal/acl/mocks/iface_mapper.go @@ -8,7 +8,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - iface "github.com/netbirdio/netbird/iface" + iface "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) // MockIFaceMapper is a mock of IFaceMapper interface. @@ -77,7 +78,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call { } // SetFilter mocks base method. -func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error { +func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetFilter", arg0) ret0, _ := ret[0].(error) diff --git a/client/internal/config.go b/client/internal/config.go index 1df1e0547..ee54c6380 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -16,9 +16,9 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/ssh" - "github.com/netbirdio/netbird/iface" mgm "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/util" ) diff --git a/client/internal/connect.go b/client/internal/connect.go index 36b340cfb..c77f95603 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -17,13 +17,14 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + 
"github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/iface" mgm "github.com/netbirdio/netbird/management/client" mgmProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/relay/auth/hmac" @@ -70,7 +71,7 @@ func (c *ConnectClient) RunWithProbes( // RunOnAndroid with main logic on mobile system func (c *ConnectClient) RunOnAndroid( - tunAdapter iface.TunAdapter, + tunAdapter device.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, @@ -205,7 +206,7 @@ func (c *ConnectClient) run( localPeerState := peer.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), PubKey: myPrivateKey.PublicKey().String(), - KernelInterface: iface.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded(), FQDN: loginResp.GetPeerConfig().GetFqdn(), } c.statusRecorder.UpdateLocalPeerState(localPeerState) diff --git a/client/internal/dns/response_writer_test.go b/client/internal/dns/response_writer_test.go index 5a0047700..857964406 100644 --- a/client/internal/dns/response_writer_test.go +++ b/client/internal/dns/response_writer_test.go @@ -9,7 +9,7 @@ import ( "github.com/google/gopacket/layers" "github.com/miekg/dns" - "github.com/netbirdio/netbird/iface/mocks" + "github.com/netbirdio/netbird/client/iface/mocks" ) func TestResponseWriterLocalAddr(t *testing.T) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index b9552bc17..53d18a678 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -15,16 +15,18 @@ import ( 
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" + pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/iface" - pfmock "github.com/netbirdio/netbird/iface/mocks" ) type mocWGIface struct { - filter iface.PacketFilter + filter device.PacketFilter } func (w *mocWGIface) Name() string { @@ -43,11 +45,11 @@ func (w *mocWGIface) ToInterface() *net.Interface { panic("implement me") } -func (w *mocWGIface) GetFilter() iface.PacketFilter { +func (w *mocWGIface) GetFilter() device.PacketFilter { return w.filter } -func (w *mocWGIface) GetDevice() *iface.DeviceWrapper { +func (w *mocWGIface) GetDevice() *device.FilteredDevice { panic("implement me") } @@ -59,13 +61,13 @@ func (w *mocWGIface) IsUserspaceBind() bool { return false } -func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error { +func (w *mocWGIface) SetFilter(filter device.PacketFilter) error { w.filter = filter return nil } -func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) { - return iface.WGStats{}, nil +func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) { + return configurer.WGStats{}, nil } var zoneRecords = []nbdns.SimpleRecord{ diff --git a/client/internal/dns/wgiface.go b/client/internal/dns/wgiface.go index 2f08e8d52..69bc83659 100644 --- a/client/internal/dns/wgiface.go +++ b/client/internal/dns/wgiface.go @@ -5,7 +5,9 @@ package dns import ( "net" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) 
// WGIface defines subset methods of interface required for manager @@ -14,7 +16,7 @@ type WGIface interface { Address() iface.WGAddress ToInterface() *net.Interface IsUserspaceBind() bool - GetFilter() iface.PacketFilter - GetDevice() *iface.DeviceWrapper - GetStats(peerKey string) (iface.WGStats, error) + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/client/internal/dns/wgiface_windows.go b/client/internal/dns/wgiface_windows.go index f8bb80fb9..765132fdb 100644 --- a/client/internal/dns/wgiface_windows.go +++ b/client/internal/dns/wgiface_windows.go @@ -1,14 +1,18 @@ package dns -import "github.com/netbirdio/netbird/iface" +import ( + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" +) // WGIface defines subset methods of interface required for manager type WGIface interface { Name() string Address() iface.WGAddress IsUserspaceBind() bool - GetFilter() iface.PacketFilter - GetDevice() *iface.DeviceWrapper - GetStats(peerKey string) (iface.WGStats, error) + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 998cbce2d..c51901a22 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -23,9 +23,12 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" 
"github.com/netbirdio/netbird/client/internal/relay" @@ -36,8 +39,6 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" mgm "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" @@ -619,7 +620,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{ IP: e.config.WgAddr, PubKey: e.config.WgPrivateKey.PublicKey().String(), - KernelInterface: iface.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded(), FQDN: conf.GetFqdn(), }) @@ -1165,15 +1166,15 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { log.Errorf("failed to create pion's stdnet: %s", err) } - var mArgs *iface.MobileIFaceArguments + var mArgs *device.MobileIFaceArguments switch runtime.GOOS { case "android": - mArgs = &iface.MobileIFaceArguments{ + mArgs = &device.MobileIFaceArguments{ TunAdapter: e.mobileDep.TunAdapter, TunFd: int(e.mobileDep.FileDescriptor), } case "ios": - mArgs = &iface.MobileIFaceArguments{ + mArgs = &device.MobileIFaceArguments{ TunFd: int(e.mobileDep.FileDescriptor), } default: diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 95aadf141..29a8439a2 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -25,14 +25,15 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" 
"github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" mgmt "github.com/netbirdio/netbird/management/client" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" @@ -874,7 +875,7 @@ func TestEngine_MultiplePeers(t *testing.T) { mu.Lock() defer mu.Unlock() guid := fmt.Sprintf("{%s}", uuid.New().String()) - iface.CustomWindowsGUIDString = strings.ToLower(guid) + device.CustomWindowsGUIDString = strings.ToLower(guid) err = engine.Start() if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 2355c67c3..2b0c92cc6 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -1,16 +1,16 @@ package internal import ( + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" ) // MobileDependency collect all dependencies for mobile platform type MobileDependency struct { // Android only - TunAdapter iface.TunAdapter + TunAdapter device.TunAdapter IFaceDiscover stdnet.ExternalIFaceDiscover NetworkChangeListener listener.NetworkChangeListener HostDNSAddresses []string diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index baff1372a..ad84bd700 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -15,9 +15,10 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" - 
"github.com/netbirdio/netbird/iface" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -684,7 +685,7 @@ func (conn *Conn) setStatusToDisconnected() { // todo rethink status updates conn.log.Debugf("error while updating peer's state, err: %v", err) } - if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil { + if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil { conn.log.Debugf("failed to reset wireguard stats for peer: %s", err) } } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 22e5409f8..b4926a9d2 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -9,9 +9,9 @@ import ( "github.com/magiconair/properties/assert" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 915fa63f0..a28992fac 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -11,8 +11,8 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/relay" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" relayClient "github.com/netbirdio/netbird/relay/client" ) @@ -203,7 +203,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) { state, ok := d.peers[peerPubKey] if !ok { - return State{}, iface.ErrPeerNotFound + return State{}, configurer.ErrPeerNotFound } return state, nil } @@ -412,7 +412,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { } // 
UpdateWireGuardPeerState updates the WireGuard bits of the peer state -func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error { +func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error { d.mux.Lock() defer d.mux.Unlock() diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 8bf1b7568..c4e9d1950 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -15,9 +15,9 @@ import ( "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index db2caea7f..eaa232151 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -10,12 +10,12 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index e86a52810..ac94d4a5c 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -13,10 +13,10 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + 
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index d97fe631f..d7ddf7ae8 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -14,6 +14,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" @@ -21,7 +23,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" - "github.com/netbirdio/netbird/iface" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -102,7 +103,7 @@ func NewManager( }, func(prefix netip.Prefix, peerKey string) error { if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { - if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) { + if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) { return err } log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err) diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2995e2740..2f26f7a5e 100644 --- 
a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 58a66715c..908279c88 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -5,9 +5,9 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/routeselector" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/util/net" ) diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index 2057b9cc8..c75a0a7f2 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -7,8 +7,8 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) { diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 1d1a4b063..ef38d5707 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -11,9 +11,9 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" 
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go index 13e1229f8..bb620ee68 100644 --- a/client/internal/routemanager/sysctl/sysctl_linux.go +++ b/client/internal/routemanager/sysctl/sysctl_linux.go @@ -13,7 +13,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) const ( diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 10944c1e2..d1cb83bfb 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -5,9 +5,9 @@ import ( "net/netip" "sync" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - "github.com/netbirdio/netbird/iface" ) type Nexthop struct { diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 90f06ba78..9258f4a4e 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -16,10 +16,10 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" - "github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net" ) diff --git 
a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 94965c119..238225807 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -19,7 +19,7 @@ import ( "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) type dialer interface { diff --git a/iface/tun.go b/iface/tun.go deleted file mode 100644 index 7d0a57ed6..000000000 --- a/iface/tun.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build !android -// +build !android - -package iface - -import ( - "github.com/netbirdio/netbird/iface/bind" -) - -// CustomWindowsGUIDString is a custom GUID string for the interface -var CustomWindowsGUIDString string - -type wgTunDevice interface { - Create() (wgConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) - UpdateAddr(address WGAddress) error - WgAddress() WGAddress - DeviceName() string - Close() error - Wrapper() *DeviceWrapper // todo eliminate this function -} diff --git a/iface/wg_configurer.go b/iface/wg_configurer.go deleted file mode 100644 index dd38ba075..000000000 --- a/iface/wg_configurer.go +++ /dev/null @@ -1,21 +0,0 @@ -package iface - -import ( - "errors" - "net" - "time" - - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -var ErrPeerNotFound = errors.New("peer not found") - -type wgConfigurer interface { - configureInterface(privateKey string, port int) error - updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error - removePeer(peerKey string) error - addAllowedIP(peerKey string, allowedIP string) error - removeAllowedIP(peerKey string, allowedIP string) error - close() - getStats(peerKey string) (WGStats, error) -} diff --git a/util/net/net.go b/util/net/net.go index 8d1fcebd0..61b47dbe7 
100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -4,7 +4,7 @@ import ( "net" "os" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/google/uuid" ) From 8934453b30309e508df09236ac102c0537259291 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 2 Oct 2024 18:29:51 +0200 Subject: [PATCH 16/37] Update management base docker image (#2687) --- management/Dockerfile | 4 ++-- management/Dockerfile.debug | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/management/Dockerfile b/management/Dockerfile index cac640bf4..3b2df2623 100644 --- a/management/Dockerfile +++ b/management/Dockerfile @@ -1,5 +1,5 @@ -FROM ubuntu:22.04 +FROM ubuntu:24.04 RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt ENTRYPOINT [ "/go/bin/netbird-mgmt","management"] CMD ["--log-file", "console"] -COPY netbird-mgmt /go/bin/netbird-mgmt \ No newline at end of file +COPY netbird-mgmt /go/bin/netbird-mgmt diff --git a/management/Dockerfile.debug b/management/Dockerfile.debug index f4be366a8..4d9730bd7 100644 --- a/management/Dockerfile.debug +++ b/management/Dockerfile.debug @@ -1,4 +1,4 @@ -FROM ubuntu:22.04 +FROM ubuntu:24.04 RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"] CMD ["--log-file", "console"] From 158936fb15596690003d602c5df918f6522b97c1 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:50:35 +0200 Subject: [PATCH 17/37] [management] Remove file store (#2689) --- client/cmd/testutil_test.go | 13 +- client/internal/engine_test.go | 21 +- client/server/server_test.go | 2 +- client/testdata/store.json | 38 - client/testdata/store.sqlite | Bin 0 -> 163840 bytes management/client/client_test.go | 28 +- management/server/account_test.go | 2 +- management/server/dns_test.go | 2 +- management/server/file_store.go | 791 
+----------------- management/server/file_store_test.go | 655 --------------- management/server/management_proto_test.go | 49 +- management/server/management_test.go | 9 +- management/server/nameserver_test.go | 2 +- management/server/peer_test.go | 18 +- management/server/route_test.go | 4 +- management/server/sql_store.go | 22 + management/server/sql_store_test.go | 144 ++-- management/server/store.go | 56 +- management/server/store_test.go | 21 +- .../server/testdata/extended-store.json | 120 --- .../server/testdata/extended-store.sqlite | Bin 0 -> 163840 bytes management/server/testdata/store.json | 88 -- management/server/testdata/store.sqlite | Bin 0 -> 163840 bytes .../server/testdata/store_policy_migrate.json | 116 --- .../testdata/store_policy_migrate.sqlite | Bin 0 -> 163840 bytes .../testdata/store_with_expired_peers.json | 130 --- .../testdata/store_with_expired_peers.sqlite | Bin 0 -> 163840 bytes management/server/testdata/storev1.json | 154 ---- management/server/testdata/storev1.sqlite | Bin 0 -> 163840 bytes management/server/user_test.go | 65 +- 30 files changed, 259 insertions(+), 2291 deletions(-) delete mode 100644 client/testdata/store.json create mode 100644 client/testdata/store.sqlite delete mode 100644 management/server/file_store_test.go delete mode 100644 management/server/testdata/extended-store.json create mode 100644 management/server/testdata/extended-store.sqlite delete mode 100644 management/server/testdata/store.json create mode 100644 management/server/testdata/store.sqlite delete mode 100644 management/server/testdata/store_policy_migrate.json create mode 100644 management/server/testdata/store_policy_migrate.sqlite delete mode 100644 management/server/testdata/store_with_expired_peers.json create mode 100644 management/server/testdata/store_with_expired_peers.sqlite delete mode 100644 management/server/testdata/storev1.json create mode 100644 management/server/testdata/storev1.sqlite diff --git a/client/cmd/testutil_test.go 
b/client/cmd/testutil_test.go index f0dc8bf21..033d1bb6a 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -3,7 +3,6 @@ package cmd import ( "context" "net" - "path/filepath" "testing" "time" @@ -34,18 +33,12 @@ func startTestingServices(t *testing.T) string { if err != nil { t.Fatal(err) } - testDir := t.TempDir() - config.Datadir = testDir - err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json")) - if err != nil { - t.Fatal(err) - } _, signalLis := startSignal(t) signalAddr := signalLis.Addr().String() config.Signal.URI = signalAddr - _, mgmLis := startManagement(t, config) + _, mgmLis := startManagement(t, config, "../testdata/store.sqlite") mgmAddr := mgmLis.Addr().String() return mgmAddr } @@ -70,7 +63,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { return s, lis } -func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) { +func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { t.Helper() lis, err := net.Listen("tcp", ":0") @@ -78,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 29a8439a2..3d1983c6b 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -6,7 +6,6 @@ import ( "net" "net/netip" "os" - "path/filepath" "runtime" "strings" "sync" @@ -824,20 +823,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { func TestEngine_MultiplePeers(t *testing.T) { // log.SetLevel(log.DebugLevel) - dir := t.TempDir() - - err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, 
"store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - err = os.Remove(filepath.Join(dir, "store.json")) //nolint - if err != nil { - t.Fatal(err) - return - } - }() - ctx, cancel := context.WithCancel(CtxInitState(context.Background())) defer cancel() @@ -847,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) { return } defer sigServer.Stop() - mgmtServer, mgmtAddr, err := startManagement(t, dir) + mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite") if err != nil { t.Fatal(err) return @@ -1070,7 +1055,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { return s, lis.Addr().String(), nil } -func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) { +func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) { t.Helper() config := &server.Config{ @@ -1095,7 +1080,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 9b18df4d3..e534ad7e2 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } diff --git a/client/testdata/store.json b/client/testdata/store.json deleted 
file mode 100644 index 8236f2703..000000000 --- a/client/testdata/store.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": {}, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} \ No newline at end of file diff --git a/client/testdata/store.sqlite b/client/testdata/store.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..118c2bebc9f1fd29751627c36304d301ba156781 GIT binary patch literal 163840 zcmeI5Piz}ke#beIMbfe*#>plVXQN$8iDR>oY?6`{$qU1kWjbrbmJ`Xwu3-ej5&1|> zM9wfXL)k*G1!yP1Vu512XpTMf)I$&3Vu3d2JrzCdvA|x7_S#zy1+s_!=FPm}d*slP z>)j$DzO>BYy!U(mKEL1l{ob45pPi3xTRzJ-9Jg)y`Q_A0DNRfLc|M;?rQW3f7wEru zyg@Hch!Z+$((6$#-%MR}|8^=6&V3Weyqf!F`ZuQ^PG3n+P3NY*p4yq*o4PQ0LNh16 zPW?9Z75$^fyDz8pYiG5TW!Kqb@6#4_&@j8c!_SST>vTJwQ3{W*^yXg5=pU_Xy{kOi zsy{Y5%=H}GY#C;)#yrpPoqc9|M%QDmVbzm!&ung2HttpOx3+FnewwcyT}?HAcPn2_ zuB>{Y8Z}leJlUDe=ZVVk!Lx6**w55nFR7B}y1?cCemxV3dJUv2DjU3f;vGC@2iBa5}AD)khf|*P2n{hnd~!L38Rb)tvbGsbO;(loLLN zN)MRR?UQN!yslZX#fC-q8Nngx=}~o)fLdvgQ22Q!!;SNo?Z`(_6}+CTYMeZ+dW3QE zCPHuIF~8d}qy(!Y4699MY3w3FY&wSHDPsG~wOLC^syVJOf+{c7X_KhyFL75&CJP3G_n#T`C}o8vQdIHDfk&1IkE=EB-3jM6MRd63TN*XFg} z+1Llg(_Ido{lv_$Q|f0*!8UYI{j34k_w=0;8GU|U`%>r5sxQ}nHKn?(e#Ue7E_tFz 
zA!kIS3(cIyt*ni@onT2+A4qlC0WC;nvn5lfg&8fo$YaDYFlk{`O$gCl&ulWIWwkA8 zJ1(7Iy3b)oePSSH$f008+;rB%ndHBJkF#brE5<~$~fN$AN z>W1ZUpK4H)A+ri%$-s&*7UjPF7!LQ@1*DF|f-Y-zTc#Ur5)Lf($QWLS#?sM#km5(j zY>x8!vGgYt?V}RIu^~yS)-tWO;;a=Zv(<7QjY34rUe~1UC%)aB3#DT5;_Ax9;>F^^;=&!%+o!Zkm)15mHm+V> z_}Rmcx4!z--BfCtI;Q1*pQ1lJKmY_l00ck)1V8`;KmY_l00ck)1ioVg&QIPyd7KLl z;r{;*soWpFV*^BYK>!3m00ck)1V8`;KmY_l00ck)1dc^ua&qQm`2PP~DwjJ}5eR_* z2!H?xfB*=900@8p2!H?xfWSBq=-FCs=KlS+=5}^=ve|2Avf1O@MBgZHu2ojc*O!*p zZeCtmxqNwf>H2c1vUKCd>gwj|`s!xo`t=*EVRl=7p1uvqW%QjL^u-hOX8+Psac!xz zo-Y@#tgK!s)0_2+m)Dn9*UIJa{r=_UKxBQXv^HAg=vs>O>4GKxmoFCSQv~$)E`NUb zG4B75gV|yXAOHd&00JNY0w4eaAOHd&00JN|nm};>AD{mp%?CLk00JNY0w4eaAOHd& z00JNY0wC~hCve=){|E2?ryu>p0|Y<-1V8`;KmY_l00ck)1V8`;Kwyjs1o!{3{~u!o zV`v}%0w4eaAOHd&00JNY0w4eaAV37L|A!fX00@8p2!H?xfB*=900@8p2!O!&6Ttp| z{B?{Gf&d7B00@8p2!H?xfB*=900@9U@cI9V*?&#t>hy#M2!H?xfB*=900@8p2!H?x zfB*=5+X?hupP0RW|J7`E=1lh5oma9mjg`h~ef7Qa(&}0Qiavk<2!H?xfB*=900@8p2!H?xfPg|E*#95r zO$&Vge;h&l{-44ZCm;X#3c|znss$B&$jB1jSh1?$2MDrS*tP6Gkj;C z*`5*n$f9A@L&SgDHj;aAW^=o;aj%lUwRNNN(|q;lu2%DRxAN8G%BmNt(SWLjCp)tl zeSKZ~e8V?)TP(Ria%wK>)4h%Bw=2A@KWb$OdZ^|v(9o-SpFQ>$<1s|1w%KMPYtQud zSUuF9NMnyXmTUT!V;gnTC(+>9x7uv6T5}n_t{c0DLRh4jzAV==JwLDo>Ox%j!|hvl zHntz+Kd3widskT8yt}n?Z+qj`*1deSvCoaiGdiZ>dq%UGzj=4Na_jxAs6ecU{C4GL zWxKMqS=q_+p?Rbx8f>AkJCo5b>soIrw$h-G7nrEiwSBTnem*;FB7<_FQSel%ykDLg zIi1$8>Z9gOSWk%`W}aR>mCwxrJ37F5k! 
zp4SZu2bPz=QMtKs|MtE70{wXE!eZ63eb$UD%ViDHM>WsWn46F(3Hyf~G5DjcC}J1- zq69;d*cKF95(Q&(njSQ#4pYsEpPw2wr$IU4W2p3iIo&>)*3avj6OXtv5rby}w1aNM>cHBgV^83!~na!`uYBR?^@Yb7R!ok`5H zng7Uf_YJdNcPX|pbrjyA=d*S&11$=FG)Z=utMnrh=(tY3Tl1Ca#)lZ?w@tg-Fl&C- z9VpN21Iu;nHnYj6h6>>28dj^{6&ZJ(0&mOQ?V}C!Yk0otcY_Ra#U5kI+?54tj%~9V zd39;V%lUlb8$2kClyTXoU5i*5W`~A-V3Lw!?-95+KGh>z^7dp)^5~S=K}40?m)Z;i zQG!S?u46hS*4bm^A(d{CmkXczFAAT6gk{Oqvq6OL+xi;{-L7ZIH$=Ro}u`- z=kQ&8h7Je8WBaZ`_E}*nti8k!ClAsY{o1_NI~)6;c)H7>v!9q5c1rzBDcFV%s-HFB z`kuaXBBRgGYhUWzS@q@muclPD)z5ei-z85JDUOT?U!j@PxRvEhw-c;l>I11RJD^3c zY_??Tw3eck5_ya`uqCa9stIwo>zPevw5+y8ZO5e(O!qm=s88$#Y-xEat-n7%YRMxt zQR=@=?@ng)@WqE(MxUG0enGB| z9Egw*#|hPXM6zl-V*Ps5iZ*I9^2gzsK^}xP9XDPK)50vu>9|(gB)1VYD!D@uT&z0z zcdKFboxjMTqwxD8i9GyZQW$?EO9`R?DZsbvCUwK|xKA}G%8*$Fv1DMG7mIRVe+-BF z>;h6pVnLTRyDig=*7XM#dt?kRE@SCv&q?v4V>U;5{aE@FiuO^7;nmX`vV2yh8OE3PjT7pq-z2^0`5@*s|@HZVMq3zsI+`sUoIxh725 zD2Z2-wA0Q8yX~3m?^1f|$Ej1L>0h1vuPfCO8=zT$?q z)zYA!y&Ky6L0bPwX_RFSvo%V?Y4eRYGx~dHwcg3txadNYAFLbWEqa*n!(k&Dk__WQ zm^@%CPtU!P(a)aMz6^~elnI~bRHKNB=jH1q{muwIMt)eF4k&S-v5vtjk`zrvq*ZT% z^#txjRLEV6YC6L8hijI2J46)=wg4ai}ufsFd0=^vKPu2a6eB@e~&LrpYHdKg``9adV-Mz zl^o@Sp4UP`{gaJ*!9@n6w#RDwLa|3{F(MO9EZ#r$ifkVpX-mE+D@sP1EL?i^S&JG$ z#Rzpq4ZVKqFTzDlWOM#Yo2&+ypt`CDAMVTIxI z%c^3dww^<+^3~Htx+xJ{FpT_zSWX7lDgGZWo`!vY0oVTA_l%P7o&3q}=X{V$>&E&> zXF}u?Zg0^glce;QQYr1dKibWqgoahuFmOdpjvU{u`#GPS%jlObX}!sCNk4emRV;zx z>#!Hp`8AXuVB+hp{^e}#rwJ%DZ&}O{Qmvd<&5nMy)iUs0u@Gju}|KbV;jvxR6 zAOHd&00JNY0w4eaAOHd&Funxv{eR=DTZ{|@KmY_l00ck)1V8`;KmY_l00aa9-2WFO za0CGm009sH0T2KI5C8!X009sHf$=4P{r~vt79#@z5C8!X009sH0T2KI5C8!X00BV& z@BbGha0CGm009sH0T2KI5C8!X009sHf$=4P`~TysTZ{|@KmY_l00ck)1V8`;KmY_l z00aa9?EeJ`96s009sH0T2KI z5C8!X009sHfk6V;{|^$uIS7CN2!H?xfB*=900@8p2!H?xj2{8K|9|{+iV=YT2!H?x VfB*=900@8p2!H?xfWRPu{{>_kT=f6| literal 0 HcmV?d00001 diff --git a/management/client/client_test.go b/management/client/client_test.go index a082e354b..313a67617 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -47,25 
+47,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { level, _ := log.ParseLevel("debug") log.SetLevel(level) - testDir := t.TempDir() - config := &mgmt.Config{} _, err := util.ReadJson("../server/testdata/management.json", config) if err != nil { t.Fatal(err) } - config.Datadir = testDir - err = util.CopyFileContents("../server/testdata/store.json", filepath.Join(testDir, "store.json")) - if err != nil { - t.Fatal(err) - } lis, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite") if err != nil { t.Fatal(err) } @@ -521,3 +514,22 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match") assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") } + +func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) { + t.Helper() + dataDir := t.TempDir() + err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) + if err != nil { + t.Fatal(err) + } + + store, err := mgmt.NewSqliteStore(ctx, dataDir, nil) + if err != nil { + return nil, nil, err + } + + return store, func() { + store.Close(ctx) + os.Remove(filepath.Join(dataDir, "store.db")) + }, nil +} diff --git a/management/server/account_test.go b/management/server/account_test.go index e554ae493..198775bc3 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2366,7 +2366,7 @@ func createManager(t TB) (*DefaultAccountManager, error) { func createStore(t TB) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := 
NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index e033c1a21..23941495e 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -210,7 +210,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/file_store.go b/management/server/file_store.go index 994a4b1ee..df3e9bb77 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -2,24 +2,18 @@ package server import ( "context" - "errors" - "net" "os" "path/filepath" "strings" "sync" "time" - "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" ) @@ -42,167 +36,9 @@ type FileStore struct { mux sync.Mutex `json:"-"` storeFile string `json:"-"` - // sync.Mutex indexed by resource ID - resourceLocks sync.Map `json:"-"` - globalAccountLock sync.Mutex `json:"-"` - metrics telemetry.AppMetrics `json:"-"` } -func (s *FileStore) ExecuteInTransaction(ctx 
context.Context, f func(store Store) error) error { - return f(s) -} - -func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)] - if !ok { - return status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - account.SetupKeys[setupKeyID].UsedTimes++ - - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - allGroup, err := account.GetGroupAll() - if err != nil || allGroup == nil { - return errors.New("all group not found") - } - - allGroup.Peers = append(allGroup.Peers, peerID) - - return nil -} - -func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountId) - if err != nil { - return err - } - - account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId) - - return nil -} - -func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, ok := s.Accounts[peer.AccountID] - if !ok { - return status.NewAccountNotFoundError(peer.AccountID) - } - - account.Peers[peer.ID] = peer - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, ok := s.Accounts[accountId] - if !ok { - return status.NewAccountNotFoundError(accountId) - } - - account.Network.Serial++ - - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) 
{ - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)] - if !ok { - return nil, status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - setupKey, ok := account.SetupKeys[key] - if !ok { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - - return setupKey, nil -} - -func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - var takenIps []net.IP - for _, existingPeer := range account.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps, nil -} - -func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - existingLabels := []string{} - for _, peer := range account.Peers { - if peer.DNSLabel != "" { - existingLabels = append(existingLabels, peer.DNSLabel) - } - } - return existingLabels, nil -} - -func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Network, nil -} - -type StoredAccount struct{} - // NewFileStore restores a store from the file located in the datadir func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { fs, err := restore(ctx, filepath.Join(dataDir, storeFileName)) @@ -213,25 +49,6 @@ func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetr return fs, nil } -// NewFilestoreFromSqliteStore restores a store from Sqlite 
and stores to Filestore json in the file located in datadir -func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { - store, err := NewFileStore(ctx, dataDir, metrics) - if err != nil { - return nil, err - } - - err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID()) - if err != nil { - return nil, err - } - - for _, account := range sqlStore.GetAllAccounts(ctx) { - store.Accounts[account.Id] = account - } - - return store, store.persist(ctx, store.storeFile) -} - // restore the state of the store from the file. // Creates a new empty store file if doesn't exist func restore(ctx context.Context, file string) (*FileStore, error) { @@ -240,7 +57,6 @@ func restore(ctx context.Context, file string) (*FileStore, error) { s := &FileStore{ Accounts: make(map[string]*Account), mux: sync.Mutex{}, - globalAccountLock: sync.Mutex{}, SetupKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string), UserID2AccountID: make(map[string]string), @@ -416,252 +232,6 @@ func (s *FileStore) persist(ctx context.Context, file string) error { return nil } -// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock -func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { - log.WithContext(ctx).Debugf("acquiring global lock") - start := time.Now() - s.globalAccountLock.Lock() - - unlock = func() { - s.globalAccountLock.Unlock() - log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start)) - } - - took := time.Since(start) - log.WithContext(ctx).Debugf("took %v to acquire global lock", took) - if s.metrics != nil { - s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) - } - - return unlock -} - -// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock -func (s *FileStore) AcquireWriteLockByUID(ctx 
context.Context, uniqueID string) (unlock func()) { - log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID) - start := time.Now() - value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{}) - mtx := value.(*sync.Mutex) - mtx.Lock() - - unlock = func() { - mtx.Unlock() - log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start)) - } - - return unlock -} - -// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock -// This method is still returns a write lock as file store can't handle read locks -func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { - return s.AcquireWriteLockByUID(ctx, uniqueID) -} - -func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error { - s.mux.Lock() - defer s.mux.Unlock() - - if account.Id == "" { - return status.Errorf(status.InvalidArgument, "account id should not be empty") - } - - accountCopy := account.Copy() - - s.Accounts[accountCopy.Id] = accountCopy - - // todo check that account.Id and keyId are not exist already - // because if keyId exists for other accounts this can be bad - for keyID := range accountCopy.SetupKeys { - s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id - } - - // enforce peer to account index and delete peer to route indexes for rebuild - for _, peer := range accountCopy.Peers { - s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id - s.PeerID2AccountID[peer.ID] = accountCopy.Id - } - - for _, user := range accountCopy.Users { - s.UserID2AccountID[user.Id] = accountCopy.Id - for _, pat := range user.PATs { - s.TokenID2UserID[pat.ID] = user.Id - s.HashedPAT2TokenID[pat.HashedToken] = pat.ID - } - } - - if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount { - s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id - } - - return s.persist(ctx, s.storeFile) -} - -func (s *FileStore) DeleteAccount(ctx 
context.Context, account *Account) error { - s.mux.Lock() - defer s.mux.Unlock() - - if account.Id == "" { - return status.Errorf(status.InvalidArgument, "account id should not be empty") - } - - for keyID := range account.SetupKeys { - delete(s.SetupKeyID2AccountID, strings.ToUpper(keyID)) - } - - // enforce peer to account index and delete peer to route indexes for rebuild - for _, peer := range account.Peers { - delete(s.PeerKeyID2AccountID, peer.Key) - delete(s.PeerID2AccountID, peer.ID) - } - - for _, user := range account.Users { - for _, pat := range user.PATs { - delete(s.TokenID2UserID, pat.ID) - delete(s.HashedPAT2TokenID, pat.HashedToken) - } - delete(s.UserID2AccountID, user.Id) - } - - if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount { - delete(s.PrivateDomain2AccountID, account.Domain) - } - - delete(s.Accounts, account.Id) - - return s.persist(ctx, s.storeFile) -} - -// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID -func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { - s.mux.Lock() - defer s.mux.Unlock() - - delete(s.HashedPAT2TokenID, hashedToken) - - return nil -} - -// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID -func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - delete(s.TokenID2UserID, tokenID) - - return nil -} - -// GetAccountByPrivateDomain returns account by private domain -func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PrivateDomain2AccountID[strings.ToLower(domain)] - if !ok { - return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountBySetupKey 
returns account by setup key id -func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] - if !ok { - return nil, status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret -func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - tokenID, ok := s.HashedPAT2TokenID[token] - if !ok { - return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists") - } - - return tokenID, nil -} - -// GetUserByTokenID returns a User object a tokenID belongs to -func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) { - s.mux.Lock() - defer s.mux.Unlock() - - userID, ok := s.TokenID2UserID[tokenID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists") - } - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Users[userID].Copy(), nil -} - -func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) { - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - user := account.Users[userID].Copy() - pat := make([]PersonalAccessToken, 0, len(user.PATs)) - for _, token := range user.PATs { - if token != nil { - pat = 
append(pat, *token) - } - } - user.PATsG = pat - - return user, nil -} - -func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) { - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups)) - - for _, group := range account.Groups { - groupsSlice = append(groupsSlice, group) - } - - return groupsSlice, nil -} - // GetAllAccounts returns all accounts func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { s.mux.Lock() @@ -673,278 +243,6 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { return all } -// getAccount returns a reference to the Account. Should not return a copy. -func (s *FileStore) getAccount(accountID string) (*Account, error) { - account, ok := s.Accounts[accountID] - if !ok { - return nil, status.NewAccountNotFoundError(accountID) - } - - return account, nil -} - -// GetAccount returns an account for ID -func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountByUser returns a user account -func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.NewUserNotFoundError(userID) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountByPeerID returns an account for a given peer ID -func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerID2AccountID[peerID] - if !ok { - return nil, status.Errorf(status.NotFound, "provided peer ID doesn't exists %s", 
peerID) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - // this protection is needed because when we delete a peer, we don't really remove index peerID -> accountID. - // check Account.Peers for a match - if _, ok := account.Peers[peerID]; !ok { - delete(s.PeerID2AccountID, peerID) - log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) - return nil, status.NewPeerNotFoundError(peerID) - } - - return account.Copy(), nil -} - -// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key -func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return nil, status.NewPeerNotFoundError(peerKey) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - // this protection is needed because when we delete a peer, we don't really remove index peerKey -> accountID. 
- // check Account.Peers for a match - stale := true - for _, peer := range account.Peers { - if peer.Key == peerKey { - stale = false - break - } - } - if stale { - delete(s.PeerKeyID2AccountID, peerKey) - log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) - return nil, status.NewPeerNotFoundError(peerKey) - } - - return account.Copy(), nil -} - -func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return "", status.NewPeerNotFoundError(peerKey) - } - - return accountID, nil -} - -func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return "", status.NewUserNotFoundError(userID) - } - - return accountID, nil -} - -func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] - if !ok { - return "", status.NewSetupKeyNotFoundError() - } - - return accountID, nil -} - -func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return nil, status.NewPeerNotFoundError(peerKey) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - for _, peer := range account.Peers { - if peer.Key == peerKey { - return peer.Copy(), nil - } - } - - return nil, status.NewPeerNotFoundError(peerKey) -} - -func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - 
return account.Settings.Copy(), nil -} - -// GetInstallationID returns the installation ID from the store -func (s *FileStore) GetInstallationID() string { - return s.InstallationID -} - -// SaveInstallationID saves the installation ID -func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - s.InstallationID = ID - - return s.persist(ctx, s.storeFile) -} - -// SavePeer saves the peer in the account -func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - newPeer := peer.Copy() - - account.Peers[peer.ID] = newPeer - - s.PeerKeyID2AccountID[peer.Key] = accountID - s.PeerID2AccountID[peer.ID] = accountID - - return nil -} - -// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. -// PeerStatus will be saved eventually when some other changes occur. -func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Peers[peerID] - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerID) - } - - peer.Status = &peerStatus - - return nil -} - -// SavePeerLocation stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. -// Peer.Location will be saved eventually when some other changes occur. 
-func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Peers[peerWithLocation.ID] - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID) - } - - peer.Location = peerWithLocation.Location - - return nil -} - -// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. -func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Users[userID] - if peer == nil { - return status.Errorf(status.NotFound, "user %s not found", userID) - } - - peer.LastLogin = lastLogin - - return nil -} - -func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented") -} - // Close the FileStore persisting data to disk func (s *FileStore) Close(ctx context.Context) error { s.mux.Lock() @@ -959,86 +257,3 @@ func (s *FileStore) Close(ctx context.Context) error { func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } - -func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error { - return status.Errorf(status.Internal, "SaveUsers is not implemented") -} - -func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { - return status.Errorf(status.Internal, "SaveGroups is not implemented") -} - -func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) { - return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") -} - -func (s *FileStore) 
GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return "", "", err - } - - return account.Domain, account.DomainCategory, nil -} - -// AccountExists checks whether an account exists by the given ID. -func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) { - _, exists := s.Accounts[id] - return exists, nil -} - -func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) { - return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented") -} - -func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { - return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") -} - -func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { - return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") -} - -func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) { - return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") -} - -func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) { - return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") - -} - -func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") -} - -func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") -} - -func (s *FileStore) GetAccountRoutes(_ 
context.Context, _ LockingStrength, _ string) ([]*route.Route, error) { - return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") -} - -func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) { - return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") -} - -func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) { - return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") -} - -func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) { - return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") -} - -func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) { - return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") -} - -func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { - return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") -} diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go deleted file mode 100644 index 56e46b696..000000000 --- a/management/server/file_store_test.go +++ /dev/null @@ -1,655 +0,0 @@ -package server - -import ( - "context" - "crypto/sha256" - "net" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/util" -) - -type accounts struct { - Accounts map[string]*Account -} - -func TestStalePeerIndices(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", 
filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - peerID := "some_peer" - peerKey := "some_peer_key" - account.Peers[peerID] = &nbpeer.Peer{ - ID: peerID, - Key: peerKey, - } - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - account.DeletePeer(peerID) - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - _, err = store.GetAccountByPeerID(context.Background(), peerID) - require.Error(t, err, "expecting to get an error when found stale index") - - _, err = store.GetAccountByPeerPubKey(context.Background(), peerKey) - require.Error(t, err, "expecting to get an error when found stale index") -} - -func TestNewStore(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - if store.Accounts == nil || len(store.Accounts) != 0 { - t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") - } - - if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 { - t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore") - } - - if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 { - t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore") - } - - if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 { - t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore") - } - - if store.HashedPAT2TokenID == nil || len(store.HashedPAT2TokenID) != 0 { - t.Errorf("expected to create a new empty HashedPAT2TokenID map when creating a new FileStore") - } - - if store.TokenID2UserID == nil || len(store.TokenID2UserID) != 0 { - t.Errorf("expected to 
create a new empty TokenID2UserID map when creating a new FileStore") - } -} - -func TestSaveAccount(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - - // SaveAccount should trigger persist - err := store.SaveAccount(context.Background(), account) - if err != nil { - return - } - - if store.Accounts[account.Id] == nil { - t.Errorf("expecting Account to be stored after SaveAccount()") - } - - if store.PeerKeyID2AccountID["peerkey"] == "" { - t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()") - } - - if store.UserID2AccountID["testuser"] == "" { - t.Errorf("expecting UserID2AccountID index updated after SaveAccount()") - } - - if store.SetupKeyID2AccountID[setupKey.Key] == "" { - t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()") - } -} - -func TestDeleteAccount(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - defer store.Close(context.Background()) - - var account *Account - for _, a := range store.Accounts { - account = a - break - } - - require.NotNil(t, account, "failed to restore a FileStore file and get at least one account") - - err = store.DeleteAccount(context.Background(), account) - require.NoError(t, err, "failed to delete account, error: %v", err) - - _, ok := store.Accounts[account.Id] - require.False(t, 
ok, "failed to delete account") - - for id := range account.Users { - _, ok := store.UserID2AccountID[id] - assert.False(t, ok, "failed to delete UserID2AccountID index") - for _, pat := range account.Users[id].PATs { - _, ok := store.HashedPAT2TokenID[pat.HashedToken] - assert.False(t, ok, "failed to delete HashedPAT2TokenID index") - _, ok = store.TokenID2UserID[pat.ID] - assert.False(t, ok, "failed to delete TokenID2UserID index") - } - } - - for _, p := range account.Peers { - _, ok := store.PeerKeyID2AccountID[p.Key] - assert.False(t, ok, "failed to delete PeerKeyID2AccountID index") - _, ok = store.PeerID2AccountID[p.ID] - assert.False(t, ok, "failed to delete PeerID2AccountID index") - } - - for id := range account.SetupKeys { - _, ok := store.SetupKeyID2AccountID[id] - assert.False(t, ok, "failed to delete SetupKeyID2AccountID index") - } - - _, ok = store.PrivateDomain2AccountID[account.Domain] - assert.False(t, ok, "failed to delete PrivateDomain2AccountID index") - -} - -func TestStore(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - account.Groups["all"] = &group.Group{ - ID: "all", - Name: "all", - Peers: []string{"testpeer"}, - } - account.Policies = append(account.Policies, &Policy{ - ID: "all", - Name: "all", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "all", - Name: "all", - Sources: []string{"all"}, - Destinations: []string{"all"}, - }, - }, - }) - account.Policies = append(account.Policies, &Policy{ - ID: "dmz", - Name: "dmz", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "dmz", - Name: "dmz", - Enabled: true, - Sources: []string{"all"}, - Destinations: []string{"all"}, 
- }, - }, - }) - - // SaveAccount should trigger persist - err := store.SaveAccount(context.Background(), account) - if err != nil { - return - } - - restored, err := NewFileStore(context.Background(), store.storeFile, nil) - if err != nil { - return - } - - restoredAccount := restored.Accounts[account.Id] - if restoredAccount == nil { - t.Errorf("failed to restore a FileStore file - missing Account %s", account.Id) - return - } - - if restoredAccount.Peers["testpeer"] == nil { - t.Errorf("failed to restore a FileStore file - missing Peer testpeer") - } - - if restoredAccount.CreatedBy != "testuser" { - t.Errorf("failed to restore a FileStore file - missing Account CreatedBy") - } - - if restoredAccount.Users["testuser"] == nil { - t.Errorf("failed to restore a FileStore file - missing User testuser") - } - - if restoredAccount.Network == nil { - t.Errorf("failed to restore a FileStore file - missing Network") - } - - if restoredAccount.Groups["all"] == nil { - t.Errorf("failed to restore a FileStore file - missing Group all") - } - - if len(restoredAccount.Policies) != 2 { - t.Errorf("failed to restore a FileStore file - missing Policies") - return - } - - assert.Equal(t, account.Policies[0], restoredAccount.Policies[0], "failed to restore a FileStore file - missing Policy all") - assert.Equal(t, account.Policies[1], restoredAccount.Policies[1], "failed to restore a FileStore file - missing Policy dmz") -} - -func TestRestore(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - - require.NotNil(t, account, "failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b") - - require.NotNil(t, account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"], 
"failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003") - - require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003") - - require.NotNil(t, account.Network, "failed to restore a FileStore file - missing Account Network") - - require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB") - - require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore wrong PATs length") - - require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length") - - require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length") - - require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length") - - require.Len(t, store.HashedPAT2TokenID, 1, "failed to restore a FileStore wrong HashedPAT2TokenID mapping length") - - require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length") -} - -func TestRestoreGroups_Migration(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - // create default group - account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - account.Groups = map[string]*group.Group{ - "cfefqs706sqkneg59g3g": { - ID: "cfefqs706sqkneg59g3g", - Name: "All", - }, - } - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err, "failed to save account") - - // restore 
account with default group with empty Issue field - if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil { - return - } - account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - - require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") - require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") -} - -func TestGetAccountByPrivateDomain(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - existingDomain := "test.com" - - account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) - require.NoError(t, err, "should found account") - require.Equal(t, existingDomain, account.Domain, "domains should match") - - _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") - require.Error(t, err, "should return error on domain lookup") -} - -func TestFileStore_GetAccount(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - if expected == nil { - t.Fatalf("expected account doesn't exist") - return - } - - account, err := store.GetAccount(context.Background(), expected.Id) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) - 
assert.Equal(t, expected.DomainCategory, account.DomainCategory) - assert.Equal(t, expected.Domain, account.Domain) - assert.Equal(t, expected.CreatedBy, account.CreatedBy) - assert.Equal(t, expected.Network.Identifier, account.Network.Identifier) - assert.Len(t, account.Peers, len(expected.Peers)) - assert.Len(t, account.Users, len(expected.Users)) - assert.Len(t, account.SetupKeys, len(expected.SetupKeys)) - assert.Len(t, account.Routes, len(expected.Routes)) - assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups)) -} - -func TestFileStore_GetTokenIDByHashedToken(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken - tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken) - if err != nil { - t.Fatal(err) - } - - expectedTokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - assert.Equal(t, expectedTokenID, tokenID) -} - -func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - store.HashedPAT2TokenID["someHashedToken"] = "someTokenId" - - err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken") - if err != nil { - t.Fatal(err) - } - - assert.Empty(t, store.HashedPAT2TokenID["someHashedToken"]) -} - -func TestFileStore_DeleteTokenID2UserIDIndex(t *testing.T) { - store := newStore(t) - 
store.TokenID2UserID["someTokenId"] = "someUserId" - - err := store.DeleteTokenID2UserIDIndex("someTokenId") - if err != nil { - t.Fatal(err) - } - - assert.Empty(t, store.TokenID2UserID["someTokenId"]) -} - -func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234")) - _, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:])) - - assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid") -} - -func TestFileStore_GetUserByTokenID(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - user, err := store.GetUserByTokenID(context.Background(), tokenID) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id) -} - -func TestFileStore_GetUserByTokenID_Failure(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, 
err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - wrongTokenID := "someNonExistingTokenID" - _, err = store.GetUserByTokenID(context.Background(), wrongTokenID) - - assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid") -} - -func TestFileStore_SavePeerStatus(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") - if err != nil { - t.Fatal(err) - } - - // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) - assert.Error(t, err) - - // save new status of existing peer - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, - } - - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - } - - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) - if err != nil { - t.Fatal(err) - } - account, err = store.getAccount(account.Id) - if err != nil { - t.Fatal(err) - } - - actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus, *actual) -} - -func TestFileStore_SavePeerLocation(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != 
nil { - return - } - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - peer := &nbpeer.Peer{ - AccountID: account.Id, - ID: "testpeer", - Location: nbpeer.Location{ - ConnectionIP: net.ParseIP("10.0.0.0"), - CountryCode: "YY", - CityName: "City", - GeoNameID: 1, - }, - Meta: nbpeer.PeerSystemMeta{}, - } - // error is expected as peer is not in store yet - err = store.SavePeerLocation(account.Id, peer) - assert.Error(t, err) - - account.Peers[peer.ID] = peer - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - peer.Location.ConnectionIP = net.ParseIP("35.1.1.1") - peer.Location.CountryCode = "DE" - peer.Location.CityName = "Berlin" - peer.Location.GeoNameID = 2950159 - - err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) - assert.NoError(t, err) - - account, err = store.GetAccount(context.Background(), account.Id) - require.NoError(t, err) - - actual := account.Peers[peer.ID].Location - assert.Equal(t, peer.Location, actual) -} - -func newStore(t *testing.T) *FileStore { - t.Helper() - store, err := NewFileStore(context.Background(), t.TempDir(), nil) - if err != nil { - t.Errorf("failed creating a new store") - } - - return store -} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index ff09129bd..f8ab46d81 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -6,7 +6,6 @@ import ( "io" "net" "os" - "path/filepath" "runtime" "sync" "sync/atomic" @@ -89,14 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error func Test_SyncProtocol(t *testing.T) { dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, _, 
mgmtAddr, err := startManagementForTest(t, &Config{ + mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -117,6 +109,7 @@ func Test_SyncProtocol(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return @@ -412,18 +405,18 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { } } -func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) { +func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") if err != nil { - return nil, nil, "", err + return nil, nil, "", nil, err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir) + + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile) if err != nil { - return nil, nil, "", err + t.Fatal(err) } - t.Cleanup(cleanUp) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} @@ -437,7 +430,8 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA eventStore, nil, false, MocIntegratedValidator{}, metrics) if err != nil { - return nil, nil, "", err + cleanup() + return nil, nil, "", cleanup, err } secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) @@ -445,7 +439,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA ephemeralMgr := NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr) if err != nil { - return nil, nil, "", err + return 
nil, nil, "", cleanup, err } mgmtProto.RegisterManagementServiceServer(s, mgmtServer) @@ -455,7 +449,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA } }() - return s, accountManager, lis.Addr().String(), nil + return s, accountManager, lis.Addr().String(), cleanup, nil } func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) { @@ -475,6 +469,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn, nil } + func Test_SyncStatusRace(t *testing.T) { if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { t.Skip("Skipping on CI and Postgres store") @@ -488,15 +483,8 @@ func Test_SyncStatusRace(t *testing.T) { func testSyncStatusRace(t *testing.T) { t.Helper() dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{ + mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -517,6 +505,7 @@ func testSyncStatusRace(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return @@ -665,15 +654,8 @@ func Test_LoginPerformance(t *testing.T) { t.Run(bc.name, func(t *testing.T) { t.Helper() dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, am, _, err := startManagementForTest(t, &Config{ + mgmtServer, am, _, cleanup, err := startManagementForTest(t, 
"testdata/store_with_expired_peers.sqlite", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -694,6 +676,7 @@ func Test_LoginPerformance(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return diff --git a/management/server/management_test.go b/management/server/management_test.go index 3956d96b1..ba27dc5e8 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -5,7 +5,6 @@ import ( "math/rand" "net" "os" - "path/filepath" "runtime" sync2 "sync" "time" @@ -52,8 +51,6 @@ var _ = Describe("Management service", func() { dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*") Expect(err).NotTo(HaveOccurred()) - err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json")) - Expect(err).NotTo(HaveOccurred()) var listener net.Listener config := &server.Config{} @@ -61,7 +58,7 @@ var _ = Describe("Management service", func() { Expect(err).NotTo(HaveOccurred()) config.Datadir = dataDir - s, listener = startServer(config) + s, listener = startServer(config, dataDir, "testdata/store.sqlite") addr = listener.Addr().String() client, conn = createRawClient(addr) @@ -530,12 +527,12 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn } -func startServer(config *server.Config) (*grpc.Server, net.Listener) { +func startServer(config *server.Config, dataDir string, testFile string) (*grpc.Server, net.Listener) { lis, err := net.Listen("tcp", ":0") Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/nameserver_test.go 
b/management/server/nameserver_test.go index 5f8545243..7dbd4420c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -773,7 +773,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { func createNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 387adb91d..225571f62 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1004,7 +1004,11 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} @@ -1065,7 +1069,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} @@ -1127,7 +1135,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} diff --git 
a/management/server/route_test.go b/management/server/route_test.go index b556816be..fbe022102 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1257,7 +1257,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -1737,7 +1737,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { } assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) - //peerD is also the routing peer for route1, should contain same routes firewall rules as peerA + // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 85c68ef44..cce748a0f 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -915,6 +915,28 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, return store, nil } +// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. 
+func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { + store, err := NewPostgresqlStore(ctx, dsn, metrics) + if err != nil { + return nil, err + } + + err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID()) + if err != nil { + return nil, err + } + + for _, account := range sqliteStore.GetAllAccounts(ctx) { + err := store.SaveAccount(ctx, account) + if err != nil { + return nil, err + } + } + + return store, nil +} + func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { var setupKey SetupKey result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 64ef36831..dc07849d9 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -7,7 +7,6 @@ import ( "net" "net/netip" "os" - "path/filepath" "runtime" "testing" "time" @@ -25,7 +24,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/util" ) func TestSqlite_NewStore(t *testing.T) { @@ -347,7 +345,11 @@ func TestSqlite_GetAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -367,7 +369,11 @@ func TestSqlite_SavePeer(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + if err != nil { + 
t.Fatal(err) + } + defer cleanup() account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -415,7 +421,11 @@ func TestSqlite_SavePeerStatus(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -468,8 +478,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -519,8 +532,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } existingDomain := "test.com" account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) @@ -539,8 +555,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } hashed := "SoMeHaShEdToKeN" id 
:= "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -560,8 +579,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } id := "9dj38s35-63fb-11ec-90d6-0242ac120003" user, err := store.GetUserByTokenID(context.Background(), id) @@ -668,24 +690,9 @@ func newSqliteStore(t *testing.T) *SqlStore { t.Helper() store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - -func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore { - t.Helper() - - storeDir := t.TempDir() - - err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) - require.NoError(t, err) - - fStore, err := NewFileStore(context.Background(), storeDir, nil) - require.NoError(t, err) - - store, err := NewSqliteStoreFromFileStore(context.Background(), fStore, storeDir, nil) + t.Cleanup(func() { + store.Close(context.Background()) + }) require.NoError(t, err) require.NotNil(t, store) @@ -733,32 +740,31 @@ func newPostgresqlStore(t *testing.T) *SqlStore { return store } -func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore { +func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore { t.Helper() - storeDir := t.TempDir() - err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) - require.NoError(t, err) + store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename) + t.Cleanup(cleanUpQ) + if err != nil { + return nil + } - fStore, err := NewFileStore(context.Background(), storeDir, nil) - require.NoError(t, err) - - cleanUp, err := testutil.CreatePGDB() + cleanUpP, err := testutil.CreatePGDB() if err != nil { t.Fatal(err) } - 
t.Cleanup(cleanUp) + t.Cleanup(cleanUpP) postgresDsn, ok := os.LookupEnv(postgresDsnEnv) if !ok { t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) } - store, err := NewPostgresqlStoreFromFileStore(context.Background(), fStore, postgresDsn, nil) + pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil) require.NoError(t, err) require.NotNil(t, store) - return store + return pstore } func TestPostgresql_NewStore(t *testing.T) { @@ -924,7 +930,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -963,7 +969,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") existingDomain := "test.com" @@ -980,7 +986,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -995,7 +1001,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -1009,12 +1015,15 @@ func TestSqlite_GetTakenIPs(t 
*testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1054,12 +1063,15 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + return + } + t.Cleanup(cleanup) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -1096,12 +1108,15 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, 
err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID) @@ -1118,12 +1133,15 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") @@ -1137,12 +1155,16 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") diff --git a/management/server/store.go b/management/server/store.go index f34a73c2d..041c936ae 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,10 +12,11 
@@ import ( "strings" "time" - "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" "gorm.io/gorm" + "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/telemetry" @@ -236,23 +237,29 @@ func getMigrations(ctx context.Context) []migrationFunc { } } -// NewTestStoreFromJson is only used in tests -func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), error) { - fstore, err := NewFileStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err - } - +// NewTestStoreFromSqlite is only used in tests +func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) { // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE kind := getStoreEngineFromEnv() if kind == "" { kind = SqliteStoreEngine } - var ( - store Store - cleanUp func() - ) + var store *SqlStore + var err error + var cleanUp func() + + if filename == "" { + store, err = NewSqliteStore(ctx, dataDir, nil) + cleanUp = func() { + store.Close(ctx) + } + } else { + store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename) + } + if err != nil { + return nil, nil, err + } if kind == PostgresStoreEngine { cleanUp, err = testutil.CreatePGDB() @@ -265,21 +272,32 @@ func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), e return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - store, err = NewPostgresqlStoreFromFileStore(ctx, fstore, dsn, nil) + store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) if err != nil { return nil, nil, err } - } else { - store, err = NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil) - if err != nil { - return nil, nil, err - } - cleanUp = func() { store.Close(ctx) } } return store, cleanUp, nil } +func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) { + err := 
util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) + if err != nil { + return nil, nil, err + } + + store, err := NewSqliteStore(ctx, dataDir, nil) + if err != nil { + return nil, nil, err + } + + return store, func() { + store.Close(ctx) + os.Remove(filepath.Join(dataDir, "store.db")) + }, nil +} + // MigrateFileStoreToSqlite migrates the file store to the SQLite store. func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error { fileStorePath := path.Join(dataDir, storeFileName) diff --git a/management/server/store_test.go b/management/server/store_test.go index 40c36c9e0..fc821670d 100644 --- a/management/server/store_test.go +++ b/management/server/store_test.go @@ -14,12 +14,6 @@ type benchCase struct { size int } -var newFs = func(b *testing.B) Store { - b.Helper() - store, _ := NewFileStore(context.Background(), b.TempDir(), nil) - return store -} - var newSqlite = func(b *testing.B) Store { b.Helper() store, _ := NewSqliteStore(context.Background(), b.TempDir(), nil) @@ -28,13 +22,9 @@ var newSqlite = func(b *testing.B) Store { func BenchmarkTest_StoreWrite(b *testing.B) { cases := []benchCase{ - {name: "FileStore_Write", storeFn: newFs, size: 100}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 100}, - {name: "FileStore_Write", storeFn: newFs, size: 500}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 500}, - {name: "FileStore_Write", storeFn: newFs, size: 1000}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 1000}, - {name: "FileStore_Write", storeFn: newFs, size: 2000}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 2000}, } @@ -61,11 +51,8 @@ func BenchmarkTest_StoreWrite(b *testing.B) { func BenchmarkTest_StoreRead(b *testing.B) { cases := []benchCase{ - {name: "FileStore_Read", storeFn: newFs, size: 100}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 100}, - {name: "FileStore_Read", storeFn: newFs, size: 500}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 500}, - {name: 
"FileStore_Read", storeFn: newFs, size: 1000}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 1000}, } @@ -89,3 +76,11 @@ func BenchmarkTest_StoreRead(b *testing.B) { }) } } + +func newStore(t *testing.T) Store { + t.Helper() + + store := newSqliteStore(t) + + return store +} diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json deleted file mode 100644 index 7f96e57a8..000000000 --- a/management/server/testdata/extended-store.json +++ /dev/null @@ -1,120 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "CreatedBy": "", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": ["cfefqs706sqkneg59g2g"], - "UsageLimit": 0, - "Ephemeral": false - }, - "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": { - "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", - "Name": "Faulty key with non existing group", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": ["abcd"], - "UsageLimit": 0, - "Ephemeral": false - } - }, - "Network": { - "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": "", - "Serial": 0 - }, - "Peers": {}, - "Users": { - 
"edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "admin", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": ["cfefqs706sqkneg59g3g"], - "PATs": {}, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "user", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": null, - "PATs": { - "9dj38s35-63fb-11ec-90d6-0242ac120003": { - "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", - "UserID": "", - "Name": "", - "HashedToken": "SoMeHaShEdToKeN", - "ExpirationDate": "2023-02-27T00:00:00Z", - "CreatedBy": "user", - "CreatedAt": "2023-01-01T00:00:00Z", - "LastUsed": "2023-02-01T00:00:00Z" - } - }, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - } - }, - "Groups": { - "cfefqs706sqkneg59g4g": { - "ID": "cfefqs706sqkneg59g4g", - "Name": "All", - "Peers": [] - }, - "cfefqs706sqkneg59g3g": { - "ID": "cfefqs706sqkneg59g3g", - "Name": "AwesomeGroup1", - "Peers": [] - }, - "cfefqs706sqkneg59g2g": { - "ID": "cfefqs706sqkneg59g2g", - "Name": "AwesomeGroup2", - "Peers": [] - } - }, - "Rules": null, - "Policies": [], - "Routes": null, - "NameServerGroups": null, - "DNSSettings": null, - "Settings": { - "PeerLoginExpirationEnabled": false, - "PeerLoginExpiration": 86400000000000, - "GroupsPropagationEnabled": false, - "JWTGroupsEnabled": false, - "JWTGroupsClaimName": "" - } - } - }, - "InstallationID": "" -} diff --git a/management/server/testdata/extended-store.sqlite b/management/server/testdata/extended-store.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..81aea8118ccf7d3af562ddece1f3007f1e8fc942 GIT binary patch literal 163840 zcmeI5Piz}ke#beYMN+aQ#_=W-C(-UmiDMI%Y*LhD$qU1kW!f=p%ZX%R*D!+NhyTukw&U@-<(bEEZE!smbz2#6Km;T zzyElfJ{<86=v@kZ9`y0u#5wD4N4aqJTQ2iP_S^B_9lJk%F*Q1#9sPQAGr2Q*CV52C 
zM!rt`KJg{}qsQ6TQ}X4LQbISX?2-MTK@C*3wquH8wQia1maP`t_Y<9|*V6LGD^ll@ z$kwZm)E2XB)6g2KR;e)CRvmMf8MfNCnWgI0_}t@b8>Q8|CFREY)zVLu^1;=Vl{@Q7 zIli*;nX-4Xa_;fwWLjQXkv?B_wCx6qFA$%a_1biI^~$Z%z;h{QsI_IqVUL{I;I!VM zp*5MGwWHZPtm?w>)7YbyZfTBg8fsN@NR+=idXvqTD;A^ARdsvc7xq&$CoI>{Y=@(z zCIp2)+PHChb>p6LvvkjGU2b;m&idxvjny0Lca?H&SLni4TbkQj5wa`H<4aB*C<3PwX|EloO^QVXj+zK>FGUT20}!<9E&R1EgR8ks2WeLme~l@ zs@p0xc2BP`0ky0c?IyGI%52%-T4fC;P8doLi>qM^;bY7jUcSz^8)jvf>o%xW=(}Oo zbt9zkz*?fVq1$$w7v{F9TrFK&-MV#GnbBJM%xqaV99H*i&tf%#TvkLHvtlws!p?rn z@BG1b=CN}Je~hVx*ybYE<`?u0E_W;?Uy=tIoY0zD5N{LP9jD;Up74-MOBD;GEw6cY2^cq zQ=G!}MmBTWEj5&&w|v8Dl3A)drtO3cP$bq()wCmGyUa3JBa&1xE#C;rqExG)IXs9( zqUDGN~+v+LLj6K~ljV3e5LAweNYOnu1c6QA= zluMDr+E*)PH89j}rta)V^c&>VnJF?N?(jL>oEV|~5p{SsmwA?(b8DwCN~36OFO`-r zPfMMXfe#9XdpLA<6Vv@pshb%JwyuNfX7#wfCm$S1%hS`+Q&~8xu3Y!iXw+?WGoHhD zg$GI~*6K`c z=uMr@9h8nS-RCf)F0mJ|rQ&ExzCJxDWQI;r=zdOZC)09qO6o{~Rr!%vt(r|uH~hmR z{nq4Xd6KRj`H5sjZUyDs7avJ!d1^}f1-UwMAijkEJ`%MaKRIeU{$%)|>9KcC&mX%- zDtQoAH?3f1O>;Oer)B9)jogNJ(vUm!gNt4z|E|~cuJiYE=-v1GeiC{3-Xu5vge-+e z0aAdY8+B@iZVR8PQj{UH;<2P>Mi__+Uq6JyeRhFRLjpmI)!Plt@)ikuI(w-0&qD)g zZ#_uyqovgcdHq28Q;POpiT>CSCsk=^dNbmzBT`zUVLlv$h^D=kM$1tOWPX5)30gsY zuDCxjCYL|~Q6~>#TG12q59HkWk(9hPHE68)Caf0xPf5PcP9%OXIYF;d_G*Hjcz^&1 zfB*=900@8p2!H?xfB*=900=yTKxajIV{78})H`EaKbgoJ&x~d=8NROGt`#cF`Q?SV zZEZ0>w@_gh=2jMpD|5w#rNv@CU#L}z+w=MPg}Fk0EtFu*P9iZ*?Ub^AO3)Jz5C8!X009sH0T2KI5C8!X z009sHfme*c>EzbfF#8Yg`v03m_M2C%gJ>=YfB*=900@8p2!H?xfB*=900@ADm`X{~f+ZJ3L{!_;$4Yd4&I zg}%f`d8j)(3Vo}T!XD{1|Gof4{3X4`=qo7bdlcwa_QG6#d9JXc%;zsIEM1(Zd(ZQW zE5)Vd`FZzl_W2?gS(z&=4;DGNmOR}aKPR3_K2P_^)AOu&y+5OESE@6gJU%0FV()px zE)7NO%6CNU%2ih5%s8@_2~DNOY7guS`6c_ouEFYyEA{#M%qRT%|48;IQQt!M)XfB*=900@8p z2!H?xfB*=900=|~@b&*N_bG_=|1g65{(t(>KRiGH1V8`;KmY_l00ck)1V8`;KmY^| zF#&%4AM5`^tYCBv1V8`;KmY_l00ck)1V8`;KmY`Y0M`F70}ucK5C8!X009sH0T2KI z5C8!XIQ#^#{y+RWMh`&%1V8`;KmY_l00ck)1V8`;KtTNd-{e0fvg`DM2MB-w2!H?x zfB*=900@8p2!H?xygURtZ;ecDZM~7nOdQW#zI`GyQCp}jRhKTz&n*>e+jE5itIVzB 
zt4nkF`Gt9{QkbXT3oH)vU3hJqS@%B~`0Il7`-Tg1g*;WO6qYXL3m1#a=jIFfr3=gM zj~i+&y>niorpeT_OK#IxRjV>rc;rAjf#`pifJcf;yfB*=9 z00@8p2!H?xfB*=900@A9M}U9-|1jGf;`{%H*?hqIAGQSoAOHd&00JNY0w4eaAOHd& z00J)^f$!z_|D^0X-T(j6^%V^Q0T2KI5C8!X009sH0T2KI5CDOflmOlTKen|sarW{X zmoKkWKP@iX#l>gc`ajI2{7v&V`%v55DOK;9H`zM<6+)5z1YvIef|Ac)6wkBbmH+jE zn^~C47kV=NQhgcx`oEO@Q-Yp&fB*=900@8p2!H?xfB*=900@8p2)tqhPA9jQsCyr+(`oC#^zCO%v2&_J2w%KIYEwkO4r(XoP z|H;P4;5v)XU1vem`JqI0o``Sf!f!3m00ck) z1V8`;KmY_l00cnbl_Bs}a#4!gp-9GO6zd~$a&fG`#6o={``?L4={JdE-(>&g*ssQa zckKT7#nk9{cJ%Af&E#Jvcale>-;8{n_b11|@ru;BB(nADBelgW+cdO>s#Pk?wpGX6WrnTtFE3E_s;hwS z_Z;ypIKH+~TD@CRZmeG|{ZuI*+|;siXI&}BS5`h#_BvF~J>Hy5%PT9==c|sk-C*$r z;#0F;o9?b&xm6P9`c5m0(L-4|LtQT`4twOx27T}j4Xw%itR2nXVO3Xqej0n!(k;!= zO+&3}4vF$tM{ly(a>ZiwxvFmO`@(*T=7i-Mn(c5~pe6)`Kiar)dv)WUaLZC!45 z?aun<-Hp{7>vxrMZC4nNt+q7PvDJE6xprrxbmRKES0GS?vQfHL+9<8Bl{OX8HJj8# zoz3O8C(`nwEOkZ$E9Hf3ZlY$}aL6vj@npY=^vbzLAyTRGZh1P%v6OsC9yD&;dJ5ug z;>o3>X<3$~r}u=#h(z&nENTxTEvm8589k+0W+PCVZmZOr`hrBo1eT{B6%*rm;Tu`gfX2Y!OMo683Ekrf1+jg55=9Z^iEnQpP zx^-8Xp)Jap*|KgptnOKs#cHIFvLe!$6_XhfcJ^C-=MT0bkDcp_Vhkz7HW#razhGca z5GMMwww02dkTD2&) zv1l*secNG8J_26z8nH6b=C(yFYYP;e@hMVZ3m*3Qkc1^1|ZL6m|Gxl`LG@8sH zpXw?=l&k5DZd1azs}y(}+IAN$*RN_jn$zYP)y1s=fa6aMv~KP%cFdYhSIH)xc1@nYyzd(QlAbXQs%A__*irU1Egx2f?9z zS1$9cFy+=>5N~69skD4~TI!q(d{8jl!=baAnC^E<-ONz1bsbbUtH<>{`QS)eo}QMT z%EDQ7<+`6nqi(C4@f^M@JWxV$WI*_G&7Arjna;FZd=^vfNwwG>O?ty-OQue9DViyf z#|Q#j++3&}6L;ITR%dEMZ|ZdJpmdDsK8G1~iM@a=6-QI@_31%Vp3o@@-Os7*WLhpx zNgXM$DnGocRkNw-hJSdZ-trWH(vX=3K(v@E@;k=yW28ghqzaM7#e-}Rc_b^d-1 zz59OOPa+TBo8-oykfrb_KnielqfX7xZQ)Z@iZWzYJeKrK^8!)f>xXc-&n^&ZNFZpj zdb^=n-n@QKXAjl>$z>qztvM-vw6yvluOCQ%O3~gc(H|S)q$&+fZ$_MTL`rKk%!h*z z(OlNjXvs@~%nxueK`W@w758Vmf}L8D|&+dft))(l9Jb^28}h}gw=xoDM>5s z3}0NA4mQ@IeFx(k-ty;+lZF{YjB?f!He1D1(bpZVWQCHCIR4~>;E9EM)E0*5k%i?gh z2PPPvE+AjRa|5B_@A`*8JkEACtFq(y5xd{a`i-V{8SUGKI|1QKzpmo?merw7KYKN_ z_QRC?lfod&>}RVM`qQRs@22GoC#B9t`t;UKRx(7;@H zFEF(@`_8j=A-(Dg?)wivJDHYOPf4Arz<63_L$BRA~+fg{XTgY2Y 
zhbknK3aTvUo=9?9K6Og^qU0{ELxSN7-;P>WI4kOGx>b8q3O-tVbBM8)evQ)7G31M= z2t|GoHUsBkn{A8q?AN-jS80tVN<~&e6mcE1Vg?KImT5UbA%37*no$qxIoJV-zUibA zp_tx{yl&5<4#zJ=-1%-E-cgR+_lxswbiSo+=aJN^379&T##3iRqKwuCb7me+D_mNO&a&IicsZ z5Lf?X<2K)9P%AsEvg<4MU@dxNyn!X!7u}KVq9bjEH)Z{jo+fkW-+0!f#;2mXI-?W4 zb@VUYNlooUZ@B22f4B*6_Uo)q($$#f2cw(1&vl{V@%zU`Z29~>pV{nD(Q!XaCvfrVORM$ zIg^%iIq8XaAu-~eqZOZu_BWatwTZA1#7rMT4e&X|1UxNka&Oq z2!H?xfB*=900@8p2!H?xfB*;_UIP63Ki2<;SHI{P2!H?xfB*=900@8p2!H?xfB*<^ z0$BgU4nP0|KmY_l00ck)1V8`;KmY_l;P4Z`{r`tw$LJvlfB*=900@8p2!H?xfB*=9 z00`jvKWqR5KmY_l00ck)1V8`;KmY_l00a&{0j&QIzmCyE5C8!X009sH0T2KI5C8!X z009ud{r|855C8!X009sH0T2KI5C8!X009s<`~-0Q|M2S=Jp=&|009sH0T2KI5C8!X z009sH0j&RF10VnbAOHd&00JNY0w4eaAOHd&aQF$}{{O?TWAqRNKmY_l00ck)1V8`; zKmY_l00eOTA2t92AOHd&00JNY0w4eaAOHd&00M`f0M`G9U&rVn2!H?xfB*=900@8p z2!H?xfB*>a`~Rivw+VXU0RkWZ0w4eaAOHd&00JNY0w4eaAaM8yB&B3x0)PL1_!NsC zfdB}A00@8p2!H?xfB*=900@9UF9EFodx_u}1V8`;KmY_l00ck)1V8`;KmY^|9|8RR z|KZaqdISO>00JNY0w4eaAOHd&00JNY0=)#V{_iD%V-NrV5C8!X009sH0T2KI5C8!X lID7plVC(-UmiDMI%Y*Li=$@Av+Fe!utoy*D$Iwmy1bdQ7X>PTladeBzaaq9p!8(-MiqJM{lD{THvd z=))OtLPtgV-0$N%iL1`vjrhXp?|qrq)8CK(?%c=YtI3h^^vE|OTf=)Jmxs?N#?Uv3 z-zUDNfAqTYYEr#XnrG(TdW#!L!Zs zQ@z0)*S3tBZj?&Qb#>3)XO^osUFPU!xp(gQ_07WCc0s$hal7zyt$2JjMeWgsR_t9_ z@p3V2teAPWHJMUZR+KN+JY%=UdKc)OnhyK4y>{zCp>JE-W$JBF^Vm~wHflCJwTwCw zS$l@N$I1Z=k;a}jOvms{+tSO1N230_XV%$lvE(rNT-J9Fg|J95JXx-0xSo%ex)2rq zaP!{7wav%c{lep*cbVDsM;lw)n``$rwzXnqpX%fwVems+KkaKXJ(70<*{mLdk(7*G6JblR&&uKOytWREittuj`?YqAGYjqjBTpX~koVlQ$7~ zBbRy2hAt(T4PjVyGE05Wc0Jhud1BSpZ8s*i&m4=@Vo4?25k^qtr5ZKE^MhDSs-qs; z)emT7a-4z*knSx@?*lnZr z^{cy{;WhmX^6M^Rv1u*~lx)jlB?=LebCC17kI;F9=_})~&zdH&(v1cU`@kS2M~=py zOx!Dkw&d=~mgLeYv;2rEw=cCBM716w{Gg@|$$B1HJuZ2~yM~H5^MwRI`vu;w` zQRyD0dkQmZ6MF$$%8w+~yVLzbrqo2P{W-ZioKo{sN=u2XN&50#WUHKqKDTpc+OAt8=KaqAJuaoZ7#;p3LaVVj{p4$gG)AgpRT z(aM_EaA8ivG3y4ojj&P49g5&$mdU@H6|?R9MGhT>-xo>b;X9Lp_#;`09|cGOo@rI7 z8>Y*Bs!mac%*u}?9V@~}l>7PwKSZ?bbqv~$QXmrnu1C;{>I=lh!kAnF z1w@lPi0#B1=$^=#YePwOeX8GF6DF+Z#HV3@pPfkjXmWzymGtcd{on-xAOHd&00JNY 
z0w4eaAOHd&00JOz1cBCy^7_ui!>PB&c78T7dVX|dbad3;SMOGGrRD7M!rZQLJv+Bh zVoP%?3;C6~{KDe(d^Ve_l=8dt+4+UJTy`!yujLk3v-#ER;?B7(fT{SXzl~iQUSm0I8Q)y0LY*{j)^*_nri zyH9D?t}U;vt=+gW^NWu^+4%Z5j}nP->Zg+aHbFmlfdB}A00@8p2!H?xfB*=900@8p z2)twjE)DODo#gsMu>b!qk^b%_8z8z10w4eaAOHd&00JNY0w4eaAOHd&a3q1V!&778 zxq{&N|8ydqKC%chK>!3m00ck)1V8`;KmY_l00cl_kO;IaB|WjT^X}Bv*4F6g&GVz9 zC%K8fHov}HSe(B#mtVefeQx3U_59qee6BEe`}X4E`r^vsdg0cs+pJkuFliL`q}F%`Nie=`QY*Y`MfW(GM8KKFLHb>S^BcU z9RJa>S^81|{aoSiAAb_u|Nmnm{l`I4BgO&(AOHd&00JNY0w4eaAOHd&00JN&2n-KT zjN$wLf(DKt00JNY0w4eaAOHd&00JNY0wD165jg4h|NZCx>5u;51p*)d0w4eaAOHd& z00JNY0w4eaATYoL{QLjd{|~T&F*FbW0T2KI5C8!X009sH0T2KI5Fi5B|HBMG00ck) z1V8`;KmY_l00ck)1VCW$31I&}_&UZ2K>!3m00ck)1V8`;KmY_l00cn5|Nj5b{kx~8_f#>z&8u?6;`T{Fsc(|WPK`S<^a(*Kx9e@kz8 zfdB}A00@8p2!H?xfB*=900@8p2>cKTTpr$0PI6&m$p8KSC%Mdk{r?X^t>^;?fB*=9 z00@8p2!H?xfB*=900_hg`1}8pJhZ^~|0fajpZ{0V8wvWs3j{y_1V8`;KmY_l00ck) z1V8`;K;T3KTGy1Zot=p*H($SbbEW)ge%Z}mKk6a>libSRvLCV!jIF&wdE36vHt72X zdHS-!-29T3&93sFEBxL6o2}90WiB1V8`;KmY_l00ck)1V8`;K;We!;NSnp{{N*?89ECBAOHd& z00JNY0w4eaAOHd&00O5bFp>VB#H8}u#JTU%|9bAP$A5S3n~|;IzZu>e zKBN40=$pjv6JOImdR=)nsouPxBuuN!p1PmaOpobC)3f=xUUlqd!_{-a@wL{}D=GD( z6{U59XPf1xdV@KxZ5cJ)D3zG&>YlyNELZn`d4X=0gYN%%+vwec^Xr?1we5m-Z{v31 z=UVajt`@aN8(OheIr|q>{4+^}kKWSw> z^ib3;)6k2W$DVq#(HO#0%cwJvwP(0{tQ=@hq_L+B(=j~Lw)C>$k*NRfnRPZ>EIEum zm-XF4AuLi1PnN40uIJkVbs;MJ;pV-EYnzX?`-R6r?=rLNk2bcpH`nfMY-`2JJ~tj$ zZy37g>eZrl=h0^2-rbF`K%@w5vv8-dS=d-FY-xOGE~$wIo5}1>q}1!G(i(}Z)Gy@v zCTce=kL;44Pj;I~r(9qZJe4YMm#0R~CDj{hzj+hXljDbp=Qqx#R8>{Jddw|`C-S#* zaeLrtag9kWhDvknTBJ17)oEM@W{GiRMa^o~nPZk_i-y1QfFifam|~q+w=?jmZ#k=+*#Xsu&vF|A5Wc`Et;0cs-a~$ ztU~%IYCMfOJu)R>@30{Tf4mii>;hlZ!;mDlnV2n!f{{6mcbZd!#m$MIpY1lMPC4OY zsC0)pJs3->msG`!EH)^*&uGLXJ>IWw5{OsoCuDw^N^y((bv^V^R0XeRG;VP`?Kq3$ zO#}|aWnQzPO9^H}SWBI3P2aO!Pqr$aShaQAjfw3u$6~cuQpt9NW{bR3qh@#&+o{K- zI_j}q{eVVB4k{+~#Os;dG`-iD@lm*kuE*;B3^XbH(InYr zPOKlHK*O=i&5{?JZhVMge%-K|6{F-eosRO%IxrpEsxymxYM=mKu42~OU6FB@De%^e z-8NcZzq;!gUenJYSL`wto4c|=$+j$3BCjsZcsZX>e4Ph{zA_H`tZ5P}-DuFT4-8Ur 
z)h!4vzg%;-^JDyTij4`T<(lzMYo zXa>=kl@fW3D6sWf3l)3B-KJ|)nO-yN zCbb=v?qRy8FrzlH7qF%LNK(B!-EYYwHIZw7PVNq;)clmvQX;Do;Z-l&b;Gp8>6vb8 z5?P^SU`Ha6tSG3UnECQUC8bVHDZe6DM-D_th~rS)dPH*EcEtMixD{>KX6TQDGo3sL ztJ+Sq7^a0;nA32~x2jZ{QNnSf3F|rWX_$7}qyBDtV)PFQHSs~>Y;OEFW1o)Q z82$IuUyc54YHFmI{I4_rIXrphUxxlZ@gGA*;)C?(lm9*O50n2iF@5fzsqi0yy}EdB zz44Qj`Y5NgR-%dPdr-dNpy7we!{#Zvb!X0WPv~wTF)?_2^f%z$A|2u4*)OJ2YA&aI zd67@{Zu#yME$#ri1>&xtJ*a4|Njv2-b4!ld@VCWsb_gb#o&g{s5xN0s`Un0Y67OZZ zhEv)L{fOIbX5B{9xs4XK5iCIbtzTPlUE6AD&`0lv)_T6&DycbwnTzUJbtB_&!x#0NY&o89Z zwTnt?Dl(piT{BCEy3?eMksh9yWjXboZab2*+l9j2v{WIPR8(a#^ITC=>cxx7mxW+! zEeXmMz7@AFIVqvVUI2t|G#HY4Za+D(V_ENa~~%d|)1 zrD8iFinyL#vZIZ8!*;x=kbj~(hETwZ3+{kXd>&souZX&jp3`2F8$+&)&A)NU1-i#pwCC<@jA}&vnuMg+3+| zmzL~>GCJJPvvWV<3)AQOZDS!R(Sq8;Nc>8Ub3&)JkU;-r^K`-sG(!Nma@mdVY zgcFPRFMdU~jgGV>Uz8OkLrrF`y?)f9Mo`fMol!$?p8d;UQB!%XGhB47KZ1lex^>ni z8E7o@gYk>Hr@B$`?7j0mwtVp}S8h;Y@cwFCv0+=MP^*0Pbe3*P_!kUA{~(r={&kA~ zhl{5{-(SGBpZcCr?|UbIw);8v(@9ld>FZ30e2?2(bjhSwI(_d}S^AUR9O}_<+%{{;ygK>!3m z00ck)1V8`;KmY_l00cl_a0&R&|10V56ZC@@2!H?xfB*=900@8p2!H?xfB*=9z~B)W zR)!N3`2PRkDHbCF0T2KI5C8!X009sH0T2KI5CDNr0@(j|62UnLfB*=900@8p2!H?x zfB*=900;~o0et^|@N|k1fdB}A00@8p2!H?xfB*=900@9UCjspLJBi>N1V8`;KmY_l w00ck)1V8`;KmY^=j{u(kA3U96L?8eHAOHd&00JNY0w4eaAOHd&&`IEb0dP~Nf&c&j literal 0 HcmV?d00001 diff --git a/management/server/testdata/store_policy_migrate.json b/management/server/testdata/store_policy_migrate.json deleted file mode 100644 index 1b046e632..000000000 --- a/management/server/testdata/store_policy_migrate.json +++ /dev/null @@ -1,116 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": 
"2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": { - "cfefqs706sqkneg59g4g": { - "ID": "cfefqs706sqkneg59g4g", - "Key": "MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=", - "SetupKey": "", - "IP": "100.103.179.238", - "Meta": { - "Hostname": "Ubuntu-2204-jammy-amd64-base", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "crocodile", - "DNSLabel": "crocodile", - "Status": { - "LastSeen": "2023-02-13T12:37:12.635454796Z", - "Connected": true - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K", - "SSHEnabled": false - }, - "cfeg6sf06sqkneg59g50": { - "ID": "cfeg6sf06sqkneg59g50", - "Key": "zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=", - "SetupKey": "", - "IP": "100.103.26.180", - "Meta": { - "Hostname": "borg", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "dingo", - "DNSLabel": "dingo", - "Status": { - "LastSeen": "2023-02-21T09:37:42.565899199Z", - "Connected": false - }, - "UserID": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAILHW", - "SSHEnabled": true - } - }, - "Groups": { - "cfefqs706sqkneg59g3g": { - "ID": "cfefqs706sqkneg59g3g", - "Name": "All", - "Peers": [ - "cfefqs706sqkneg59g4g", - "cfeg6sf06sqkneg59g50" - ] - } - }, - "Rules": { - "cfefqs706sqkneg59g40": { - "ID": "cfefqs706sqkneg59g40", - "Name": "Default", - "Description": "This is a default rule that allows connections between all the resources", - "Disabled": false, - "Source": [ - "cfefqs706sqkneg59g3g" - ], - "Destination": [ - "cfefqs706sqkneg59g3g" - ], - "Flow": 0 - } - }, - "Users": { - 
"edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} diff --git a/management/server/testdata/store_policy_migrate.sqlite b/management/server/testdata/store_policy_migrate.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..0c1a491a68d58e019b5256b60338ab8b5f5e40d4 GIT binary patch literal 163840 zcmeI5O>7%SmdDwmB~r2}#_=Q_CDDXd;@CtanawYY9vEI)q7#j6*^w+b9wTTr$s*Mj z*-dvhWov>gKsyO$2UyGwHpc~KPji^V%wmAqoaeOI!(jKaK(343!(8^3!vdMZzW!u4 zKdcW=7MABv%Otz%Rdvzx=C()cwIJFqONK!$w*{EFNHWR^jC>QC=_~={-2}& z?&EcObId)USI+l-)Y~^hi`qYqo5JyLO_^8X-%kDZ#J#DD6XR3y@vp`=!dv6#!pAsi z?5oi4LtoNAdc61Y1b=yk3&~22J<=aEsez)@F;weVY-nnyt&4*Fda677QiT7oz;!QK z*>dfX*k+opDpFIFs#T`zqM>dxMHf3d(?q#8Jon`CdZ~E3lvrI`DSebE?_EtfadRzE z9$s1bT-iBUIr(^FI>HwU+-F5Ysx;Z~0>e||PMdBQuUs#UJXhizwYHov*dt@!JFRo5 zNG;}OZAtnTtJyHzH1?=1Ymym&7*Y75-mG0WDOU^IfT-&(4UR+(fohaA0EnVnhTM`XjY?KpMZ?2bC-(PbI zcY@@JJsinR0<>Zq~$0Iz?b6?!G%)k<{o=yanY?lq_)K`t8R!wbs zYL#`78oMJ`nFY0+P&zH9$<_I?VrrE&nRP;6x?f!MTZlQvoZ;o_ywX&w+oo>4S}lDy z)rPG26z*9|&^BaU@0f*6+f1yKt`;}1-%iX)ZFz3KEGq_UIJT#;IzcWctTd(#$@B>u zyKT4gd)t}A&K~?jOwGqO8L&3Dplfi+6BGO;ew4vkS`!89C03d0H>WlWnv->We9)Zw z_rTfh3`s4(Emgi*8V(p^ajNp*;)Tp{iAXurHko-v`Vj1VpT8@vRDy({CEV~MPGv4b`7gVW+`r|y5ToKE3u)9svZ#AW}3p9fuyRcxkga7O0}Dk zVFs~)R9`)+F78mz$lC>^9vVZFJMz$EyD>wv6lOeBwQW(V)ijDbEZ7TuS2tM89IP^h zCK^u_rUlyJ2(&e|)~OnSk!f|&$!|$Yr!G~Ej@DP6DLb;JDlMjvgSHi5m8;9mUQ_(I zYZMflQl*F1)UT);lF>0U$gk^+1%|m_psFegt5S&YIR}5fOa5loOV-5JjZdBfcwI(vzcL8sKq^aWenLG`lwT;G#-k45;?r@1e9%USj0dhf=A zZmXAZ0N>>wD87(0BGTDrPW=wdjXG^}N>uAh)z}VANd0C@rcM(xns$-L@M7SwiB)+> zi0;no#lQ?u8tgtE8)J51+B+T4%&`88QyDp z?3~l_$M%s(9)vYi&6`=%9L~vUYjR5>x8a=B=MLTABG<^j%XPWu{M{USb^X4ZL>|6B z$&NohOEIGWDZr4G1~o(0EuShJWvq;#H z*+X$~9_mRu>p_YiZK*NJ>wD7oDcUg6 za|7IvpykzPi@OtJatRa=W%3}Z7Cgb=Ku(??o8XscM~yYtghj!97dF?~q0slHWAx0$ 
zS3>lM2MB-w2!H?xfB*=900@8p2!H?xfWT7-bPL=oo3R_SZ%l69iA7IF$D`4xxvs9% zh3ZmjDZ5aSa;b%Em0egUWHW_@Og5j(q*6k?nyI8y>Fk1#T1cf6LjGbZdoi6^%w?9c zg797{eKD0Xm~Jdq)fR0G?a;1~WK~H;wxq08WtyrKmh378tIFypbE~)JF3t(5)M7rn zm|9HD&ClJC^leHze}1W0EMB@a_vyWl*S`Gu%}{8H+R4Sg4$&VTAOHd&00JNY0w4ea zAOHd&00JNY0!NI%+3@D%e)b>i_5U}a_%}zagJ>=YfB*=900@8p2!H?xfB*=900@A< zQwbao&rZ7c71;Ox$3yY>Q;Q%I1V8`;KmY_l00ck)1V8`;KmY_@Bm!N9i^n!M-=5vr z*oa0ipNvNLa}m9mUS2BY(^nQUOILFX*<3EOa3v#@7FJgB`Q>~ezg)U9AFQ(~c{Zy`y$uFhT_U-@aj44uB5SB)Z>|INW zK5VdH{UuT<`cMJ=y=Oh&``Etz|EEy=PcM=h(H9T^0T2KI5C8!X009sH0T2KI5C8#} zKsX$m#OMEA8h8Z(5C8!X009sH0T2KI5C8!X0D+@NV85ULH}C%+JMp{F^i_Jo0|Y<- z1V8`;KmY_l00ck)1fDknPo~Gt<~CVu;T8Hx<(Zk-+p$k88(Hl3S zCu6Z_wayxOy`IYJ54IK7$Q2s7)PC;h{_&0C&1;)%_3mn?qhyuTqg(4jC9_()`=eVe zefR#|)|K@AJIkG>x}3daebc&-$}Fbyi^5WBuA*vBUs`(NbQ_XE!VK zo1mQ<+hI+$ZT__0Cv(e9wNs-nW~Z-czvcXbt$FlmjedT`2XE^wAY z&X@WW^441$lcbQ%s81pF2l@q{Lb8qhDCEX!uJu9vhg+3K<5qnobM--KRsX29^ii&| zaffZLrmt_^EZk_9O2(aHPa(pE!eTnJG-rcfNT*ZTh5J&gwYwm-YI*Z}`E~ZJnyP9l z{l+W(#IE;Gf%e+$L`NKbfx00@8p2!H?xfB*=900@8p2!H?x90dZeg>&4nJ+Co7oSUSK zfxQ=d{r`(l{NLlhI0_v^8$kdBKmY_l00ck)1V8`;KmY_l-~}X*3N!8?_o9rQ53|XG z+$&+P|9=yTe;xnL3&;;$0s#;J0T2KI5C8!X009sH0T2Lzqe$Rfc$3@D$0f!N^!fkk ze+|WJ^n?cpfB*=900@8p2!H?xfB*=900=zq1iG({O>b_#5{<@AMlatu6^-rZTM_oN zl2s(_3CO^E*R009sH0T2KI5C8!X009sH0T2KI5csYMbhv14^VF&BXslXi^#}Te zR9=6ut*}O}(8x6Q^J@UbW|Mww;NIL|`D|lu-b`-f^}4UxTx#xP^ZGv*|6_>$@Bjf2 z009sH0T2KI5C8!X009sH0T4Li1o&`na&TiH(}>0YH#E)tI&|Wj_`jU^yQ$xvxHoli zVtgt-{?+(K`0vA8;bYvd$G!^vKJ+F1qsMzMPwMwB(|BRtBTYVrD~Pwx@f4|OwmR2M;1l7W<&mN&oR6O zCzsbt#oMLC>e@=_qeOY{rj`>o*AnI7m6gwxoeq_gk2j_xe4)U7Ry3qalMOF0JT>mL z>2~qT^^$e2KWk+}^iWQmqpp_|276@8dwp;Y6{*GCtSw33Vl`WPZW?>kmNm(cRYj~x z28o)_hTLNF<*LT$eNC+Fy25UXWccNpl5Uu`Kuz!p-&$Y2QCz>9xK_Grw=Oxqd~=v|d_UE^Q>NuIZ#E>TEJu ziADGv&vnN=D>V!0rirQ@#UQ)1j;9Arq+iZ93M-W=@0F*MoS5J*@uS9#U5{YB#GYI_ z9^rYO`{J%;F;=4WbRuXER$5SFzB4*X)6}M?G+7s^UpsP@S;)!>rPE@XT%9i~rp{QC zSts>rpjCNO}m8M$VHZ|wfYH72nHe|)8PR|yCnwNFGV-_|oFR@a(THL&T 
zJ26N5nRD}HSut3{u`G?%Ngw5emBzFonLc4-x9xU*Z!2=x*}iCqA^F%Q1GeNA^vr3h z-<;YkXinDg@j-Lymvem#mF_dA>ys1wS)P+Ui?xexGwLx(PmQXZ1cH^C3CW*CB9_Jd ztmXJ9s=}%#8nieoZ7++nPGLI`of)0B=u42>uC=tt*2FDUH~dy*B{o!1)dOPNOjB4h zkW^JQSF>fSRJ$n|byaHxr26Vnb#aG!Mh+?<_0Sla+>wVS+l?8Tr7+{6s%?u>t)@|I zW5HhNySl+z<_MH2{Lvt(FfGsyN1&~#wNBLt3^%KbPJT;LI(4aPbhN(mOxck&RcSGW ze5$Pgt6W`f_L}0yU8BI;lqx;6rhY};kc^I*L9SS5EHHNc0##K}Se3lGZ^Zlat${CE zL1Cnf#vXKJVkJs#>h_LAO7^^m>E5iN?%0yP<+miA4w+&`RR8&X=d=US5D~Lq`*cdI zy~W5w2AV})&aT?&zZvd&X6?(R$f50uRkh|BYA;hZb_4qLavDsvG6FvC0KUr_p@Tti z-@YpueOj2ZYZt7S$(@M^fB7`mo$-8-H{AW9vzHhdbV|KUU$C_uR4=Q~^*wp_ScE@) zn)`ycoK;V*_ijAswt5){@Lm3a;)^3A!k2C4)bGG_rqecOF}1!_jqT8+*Kf9D>NJ<4 znG$&nFR%@p3zdh&-Ht9bnAntCGM(EiJ;ZbmU`9P+&tOZL@d^I@)1#(5zEcRj_Y;+H zgwM=!UCy&AH@u29wI#`ldw6Wnn%pc$(zYWvk*vtBpq%{t78l`XXStt}t0M>EO1Q6M zLF;jogSO+&U-z2PI_Gr!v3(?x2Vo6W^CrVIF>`X-n%t7eZ8#_OxkER&$Tjlsa$W8@ ze>aC-UBB-pk%#Y3vg41>Qp_ko3NU1)LCuhL%cqJIWyq||SkgDm^F%FQzYmA|^a8$y zc!C;hbefXp%#3TWXB*`kwTCiuO*4!PqcNs@jy~R=`;Yq@-q3 zeK-mc&1E%-mb?_m+yFNuXnFP7;_hshTml6|nLLQ91y3+Ikdx=fCivyqQDeIVOcZ zh<`HuKe1m;|5NPriQiG-KiF$^_qqG(TM_=Iz;z4Wz%@Onwcw!cJII~WQ*;~5oEsd_ zgF^1Wu;Qb+03Q_T3m1<+osIB{H@4V{Or@$fsy3;+ypju^!W% zxD{HiC1^UX>$m4D-h7}X=?^+Il{IIh-V~db2ekb6x!^&w3RX@%G&k*N>TA%(!d_l! 
zq=DJ+o?&Y7-WyNbg>u68uivs-ILuF)FJDiv4>QN%UWs_HGw+p1=Gh0Fs{laz*6&)yEm z>YGd|u@uw4k=N^a(BZhHh`YI)XKu@`l&+U(FL$}Pv0Pj!B`>@-!CyNw$^|*~+QuOG zkKRyTjqpF9$>_?imY7xo6;q7>QZQ zUQXzs7Gmq4Y+N@t8N})qt8TlB-CK(e8E0Tw?F;V6_Rx{G{F}0FNk@~(^RGN@QsYt) zZJp7HUOWER_N1nMsy|!|%s=deR|j=AAZcsN@q@ul-2+{yc>K;uE4F<0wxwLVLi_pU zpkkeK9YC%A-P3uxC}D0GI{v|(PMZ4^{|}p|cH5u9wIBGJ(eP_0-`Vw?Yw-zQER1v} zM1IKSE!r{}mQKIB#a;NbT^t&6;-G8j+fnmJj-z!w=l+=pfBrnz4ck-t{%u!x3gqp> zo(qnzfqWB_HSg+ePfIi{GB+NR-HCT1{QF6+oAd<%zhex90)NKYLFW>5slJkKSn!u8 zH|TejkHd2jKAGg6I2RHF-Z@zDnP7i|nL(TI8^Mt2gRcROKRJ9f2=uM*DxL2&o(JlPu>`V_Fyd7*yAkWpNH_`X!(n7-AL+H2@#KGX5wD14_Iu!rf zUCHAW1V8`;KmY_l00ck)1V8`;KmY_l;HVHd8{VAU&zA<6-~Wf-|2ryLMLR(N1V8`; zKmY_l00ck)1V8`;K;Uo$@cVy<qOK>!3m00ck)1V8`;KmY_l00cnba0IaaKO8rd z0s#;J0T2KI5C8!X009sH0T2LzqeQ^G|DTJ08=^lvKmY_l00ck)1V8`;KmY_l00ck) z1YSG>VJ;ks;q(75o?_7>5C8!X009sH0T2KI5C8!X009u_CxG>TKM@>*00@8p2!H?x zfB*=900@8p2!OzgM*yGyfAMsR9)SP|fB*=900@8p2!H?xfB*=9KtBPj|NDvH7z987 z1V8`;KmY_l00ck)1V8`;UOWQ0|Nq6)DS89~AOHd&00JNY0w4eaAOHd&00R94{ttLC B#c2Ql literal 0 HcmV?d00001 diff --git a/management/server/testdata/store_with_expired_peers.json b/management/server/testdata/store_with_expired_peers.json deleted file mode 100644 index 44c225682..000000000 --- a/management/server/testdata/store_with_expired_peers.json +++ /dev/null @@ -1,130 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "Settings": { - "PeerLoginExpirationEnabled": true, - "PeerLoginExpiration": 3600000000000 - }, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - 
"Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": { - "cfvprsrlo1hqoo49ohog": { - "ID": "cfvprsrlo1hqoo49ohog", - "Key": "5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=", - "SetupKey": "72546A29-6BC8-4311-BCFC-9CDBF33F1A48", - "IP": "100.64.114.31", - "Meta": { - "Hostname": "f2a34f6a4731", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "11", - "Platform": "unknown", - "OS": "Debian GNU/Linux", - "WtVersion": "0.12.0", - "UIVersion": "" - }, - "Name": "f2a34f6a4731", - "DNSLabel": "f2a34f6a4731", - "Status": { - "LastSeen": "2023-03-02T09:21:02.189035775+01:00", - "Connected": false, - "LoginExpired": false - }, - "UserID": "", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk", - "SSHEnabled": false, - "LoginExpirationEnabled": true, - "LastLogin": "2023-03-01T19:48:19.817799698+01:00" - }, - "cg05lnblo1hkg2j514p0": { - "ID": "cg05lnblo1hkg2j514p0", - "Key": "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=", - "SetupKey": "", - "IP": "100.64.39.54", - "Meta": { - "Hostname": "expiredhost", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "expiredhost", - "DNSLabel": "expiredhost", - "Status": { - "LastSeen": "2023-03-02T09:19:57.276717255+01:00", - "Connected": false, - "LoginExpired": true - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK", - "SSHEnabled": false, - "LoginExpirationEnabled": true, - "LastLogin": "2023-03-02T09:14:21.791679181+01:00" - }, - "cg3161rlo1hs9cq94gdg": { - "ID": "cg3161rlo1hs9cq94gdg", - "Key": "mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=", - "SetupKey": "", - "IP": "100.64.117.96", - "Meta": { - "Hostname": "testhost", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": 
"development", - "UIVersion": "" - }, - "Name": "testhost", - "DNSLabel": "testhost", - "Status": { - "LastSeen": "2023-03-06T18:21:27.252010027+01:00", - "Connected": false, - "LoginExpired": false - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM", - "SSHEnabled": false, - "LoginExpirationEnabled": false, - "LastLogin": "2023-03-07T09:02:47.442857106+01:00" - } - }, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} \ No newline at end of file diff --git a/management/server/testdata/store_with_expired_peers.sqlite b/management/server/testdata/store_with_expired_peers.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..ed1133211d28b5b2faf4963e833b2ac521ff7fe0 GIT binary patch literal 163840 zcmeI5Pi!06eaAVXB}%esuCv+9D!WQ(InJ7}C2{yi6df2_N}{xi^~b9}>$Ml@a7aE; zW05oT%utrRNCB?1Mi8J#i{@CMIrY#(QUqwW*PMzTg7#7%*P=c2(%T*iY!3a+U(WE4 zCDoh4t9)sh!+G!b-n{qu{oe2Q-pp{g{r;M4u=t*;H6$aR4!suQxX_=(nl&vQR27*e^;`WNV*8gtsTQ&?Cl4m?);2DP>nH`tMJ%R8*I zt4IyzX6;M*KC9X=+%$I7lr_naRYj~y28o*QhTLGcN)?UK=c-sfc7@#($?(h7CEYO5 zQWLzw_qJBo3tJE3cZ(0~)+KH&ZfX^5#}?b!EdT z;IWHu6_<-!#f`<{cHHWkPGqUGiNxL52ruRbnMi+}qR>D6wQ@AVXL+tY>{*c60^Kwp zwWSzjGuHm)UK8n-vyH+^rIvQemlDtJT#WEM&pm%&nSmu@y3(t1Zz1LwbB33z^KxCS9GJTGYPIxT zS8KB3Q@Cd>LEDgZy=4|QZ8N@9TrTXc?Zl^~raX14Br681Iku;Fnv=DEvDcis z_rMt{&ZES>phv#I^V(p>_jNp*;_@KH;AXurHkoa*VVj1U88;*~nDy(`&gT`s4 zl}<2D>kzg#(wWg}ioOK7=^9pp%u?J}b;ECfR$@&RRXrefz%+%`14$KCbB&;6m1@={ z!wg~psjhldT|A_ok+%y-Ju&(wx8%Obc4PWxDa?4HY6qfJt!fl^Sg;rRv2L)2Iap;1 zO*Ed$ObfKb5ol^^wN)_!Bh%`lli!e()}B-`T3T0mrX0$esx+8F4%$|LRc=qNcbekI zU8SH{m&zTqrhY};kc^g@L4I9lEHKRd0u@zJScO7_&pG(>EgvCT5oVx_#vZq1VkJsV z>h_^TO7-3t-^>Q(MJZ#k=uT<6np 
z&~0@x&fvTJ1H~6|21GjB%&FgjxlyZWPKl~rsTw<^38~*~$<%3LM$<0x7+wtQH?b=9 z3DGTGsxh%HH)J}tSGtesp23Vd#9qRd(!(SC%GE(3({~7=^LeBkj_~P8uFZK?dpF4Dei(DoDF7L@5=kMmw*7f^t5_$OU zBs>22EX9lhqyR%!YSau_w|uHdQHIRQj3r$&LQmB4^`~&SFD~F~h$pDATB|N;&LZJZ zW>3W4d8jAttOqH6G^N@gukT4eqG<1w=#35iq$+huZUmflKuW6D)hB}x(X>~SXgNxO z%nfjTf|ggGE$&W?$t6%gl*xmrTJQk91357}G{P@V4jOB&35$aJDQvE@L!qx_*oo-x}M!A053iHat2yYObrx zdqQPCIiHy>OWEXfrowX5`Aj-LozBc<)5)Z;S4o#s$y8=qNKPkHaUpj*nY^7&&16%# zY-;{yGDZ6a(~X&m+Mp9dhji9RvZ|ya`%=cLGDTGi^LCYjRb}PlsnvT^x2J?;a%L_w zlblIT-I`jL^aDzpot-Zf3U}^I{qUn7gqM^5`18;>wUdkeE<}HLfdB}A00@8p2!H?x zfB*=900@8p2z&^z=enC{8ae&CM;&<>wZQ3kyqZPiobT zI6VzXWAw}pdhi6@`!7r<=ck2yJe9ninY*2$oAr~~e0pv^m9lUDPo+(f{IoDXSmfkd zlJsSRY3nae-!QQLZd&gj{lLEd|NBtv_vcBC=nDvd00@8p2!H?xfB*=900@8p2!Mb~ zARLa4;rstC4QxRG1V8`;KmY_l00ck)1V8`;K;WxK;I!ZWH}C%skL#h!-@a6&4PGDs z0w4eaAOHd&00JQJWfOR|8NQL-Wzn0H^n;m-a@Lsfw;qJlmaz%TTnP0rWbFi^no~^30N;V}6 zYuW01nWg=e*3Hbr*1FfMTs)b-O^+PU`ofXfj(@=;z#o{exQSQC7$_lc&+>R(<<8b@)_WNxgTxzOb{q z+W2r&{$PJ`L6#S!?Bk{E(bM^Gre#X;iIzN-mOgX|g$=hdSJJL(6IvO7^gN?JL!y|QuBkju089^PEfRn{{1jh&~v+D`MpJERaNqsz>b>CMaw zxm-R!m#1G3Fg*&1^!s%I-_SL$|A%A04#oZ}_Aj)-3j{y_1V8`;KmY_l00ck)1V8`; zK;V@j@Mbv6^*c8jwljK9u?_Lz>{zAG85(>2uY_X%8vC1=^2&4&4FLfV009sH0T2KI z5C8!X009sHfv+%uWSDVhxi=|n7C6hTFGI6oHg=XfD(v;E15 zc_B}a&oCd4AvNXpn;f<6jZWM4=!pc?xm@Z+Px?4b+gdu)Uao28_5Y#RFGF-zs`&N6uGf4cwwkAxw*2?8Jh z0w4eaAOHd&00JNY0w4eauR4K0$o>DJ*IxB*qkSL%0w4eaAOHd&00JNY0w4eaAOHd{ zL%_WLkM;k{5JXuJ009sH0T2KI5C8!X009sH0T6ig31I#I>gxzC1OX5L0T2KI5C8!X z009sH0T2Lzmn9H2E5QB#FN+c7K>!3m00ck)1V8`;KmY_l00cnbl_r4e|F5)q&^QnP z0T2KI5C8!X009sH0T2KI5cpCEL}ULGy3G9|bm{ZhKVABZ@n2o~X#Dob@OW(av*GRV zUxoL>7r0*xeHQvn=x6kgUN^rs!oNGgg=D45j`YWM*d$rY%)z(6{#*tl?u~!(NGVVqKoDa8H#e%?$RIj82wvtWpS%m*eS+W zHx% zETo$zspO`pvHVhbd;v4bx&!sE>gb^b@D zgVh|%(%2s9qZGH&nARuLCu|%y-OitEMGiaL7xghDAKOI0mfV7#IgNLlQOX-*bYQzMyo0M66B_9Ee*0YabML9zg1a@HC0sgfY<@k6jl!;RaDK@Y{@Fs ztV_n8sx<;qUG=ECct|}X2NjTdV)RXJ$$gXU#`MinnDIo_4n(P1)hM>HU@!Dz-CzxK 
z1j-ctXpoed7HEef(A3mwt6~I(o7F`pzac5DJ*i@}w65|@Ig~Y3X)uL+s;vO4+@4(T zG{uj*N`bd7l{;un{ffFF87(t|T(QnrVC?z@DypKe3VC(ki1+7P17Ea)!ax~~J#NXw zN|c(^?L&!_?0FB6BP=pOJ?QG>g2P zUA5DHGu(B|+Lud_Lpv5LYSlB;PNr-e2lVUZ)R<~z1bo~Xe3vysdxPMqeOF@iMPbUW zU9dJ|ha(aG-K$)C!t+7iaQBDKPGY3jDRnY^!Pa(Aovbd`_w2hDBK*~>+;iS?Rvo#{ zr{SR6>SUb3clig3FOCcdU$&W3zXQ{mR@0otRJ&3&c1V+6zuA(h(_D&XO5`!Tz}9ar zRO%CVTe?(ZVqI>?bZoD5AJaX98Fh%gge|3qNBEVigQh&bLkOMEBjs>}Pfv1f&a)~v zyoyz|A<2rnd!g5w+$=}Zwj(!@tjMmQl=$=>7vU!-xu1}$BM0J2xZ9zi^|;AF+i~Zw zC(US`V>H;1;a-*=P9 z!*?gy@yBN=W)vU=7_w5MX2`nbQ$>n0WL9P@>6+$wqL!~eg~NSu0bfHrL5d#6NiZ0ILdsY`Mr;H(2uQoXJ|8H9-DvYJFo zUJ7Jxfa?>qy!vc$ceYC|fdZmT9z@lG2k0HhiP@nMesOZpSaVHS6x>f?T4|4(tL^CM zZ$o_Oz0gHr{4d8Ijolgjx5%H5{&i$>xHR&g7ycuB`NFS<{xt+~JnlbV(rJvVN5>DlyIo2=V5})x75NlRMSQc~KGtKp6SqRkwFFJa zb^Z37#hVW_B>i!Vrn2U2)SF_{@_?5AJ{R0;R>8`thvrE;n)(`aV&NpOG|<3ocrP)v zxcSzL&O$oX3-~TQESn{NOX4Osz3GZUOVzt~vCvmR^n0($VLOs0c-V zD{OksMORxI>DjHdELUlbW|a!8gec+~YDM)H=1o;Iyh7%Vs7XrAtLJ0~Wc5uZl~{`D zK9SeydC=jwrHH$EHqSgQyHs2&(z)El!uDceshG&UIl|w)KF9?*_1eZD`1jsYz7gTS zN0ZSjLCZ0l*ct2Y@h1kDOicem=SMb!%e!?4V;TB;_7~?_(rpB_}zdvs#F) zf3k7iJjozd_F3h?RqV-HbjUaZ%W7ZnjBE!TY0H07)-CC1GBNwcizYQL717oi9q7%A ze`Zf=_FnG}7d`V2JK-C>I_r_NHRkxi;7Q#xU8wlvJ6EjO@}p~(a_tK3_pb#N>m2J0 zYW1Hzy+s!#%oB!=e{iRh=6Q<$hm)sv+h4-9pZS_m|7$0IuuOzs3!WRVmjxi7l{2A}`I+vhJ^_6tPg1-26AA8_b0IO{or4u$5B4{h8MFz%5%ifp_!{8&lmB}(2=uM%DxL2*FOCV{yEeja zCI$|_02DJa_z(Ks_Qk6f!uI=?;n?eKcljOM`%VyHFw;bM(yHF|y91Z;iS|%D!oTwl z*Iu+*BP(<|s!o?N$W!Zqoom6i1oB*MdJ}zrE-fU?a|kVWg4i3p6S)81J%fQQ2!H?x zfB*=900@8p2!H?xfB*=bUjlgk-}%)odIka@00JNY0w4eaAOHd&00JNY0xkhu|944X z3j!bj0w4eaAOHd&00JNY0w4ea=a&H1|L0e?=otur00@8p2!H?xfB*=900@8p2)G1r z|G!HDTMz&N5C8!X009sH0T2KI5C8!XIKKpN{r~*x7Ci$25C8!X009sH0T2KI5C8!X z00EZ(*8eUEY(W47KmY_l00ck)1V8`;KmY_l;QSK6{r~4zx9AxNfB*=900@8p2!H?x zfB*=900_7QaQ)vUfh`Dt00@8p2!H?xfB*=900@8p2%KL6SpT12-J)k800JNY0w4ea zAOHd&00JNY0wCZL!2SO&32Z?C1V8`;KmY_l00ck)1V8`;K;Zlm!1e$0t6TI81V8`; zKmY_l00ck)1V8`;KmY_>0$BgMB(Mbm5C8!X009sH0T2KI5C8!X0D<#Mz`XyTi+vHI 
zKfFKy1V8`;KmY_l00ck)1V8`;KmY{JAAv9z4n^_(|MRC<^aun%00ck)1V8`;KmY_l z00ck)1iA@e{ohRl`yc=UAOHd&00JNY0w4eaAOHd&aQ+D3`~T-pr|1y~fB*=900@8p z2!H?xfB*=900?vw!1}+N2=+k$1V8`;KmY_l00ck)1V8`;K;Zll!2SQ{Pp9Y+2!H?x WfB*=900@8p2!H?xfB*<|6Zl`7K~6CM literal 0 HcmV?d00001 diff --git a/management/server/testdata/storev1.json b/management/server/testdata/storev1.json deleted file mode 100644 index 674b2b87a..000000000 --- a/management/server/testdata/storev1.json +++ /dev/null @@ -1,154 +0,0 @@ -{ - "Accounts": { - "auth0|61bf82ddeab084006aa1bccd": { - "Id": "auth0|61bf82ddeab084006aa1bccd", - "SetupKeys": { - "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD": { - "Id": "831727121", - "Key": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD", - "Name": "One-off key", - "Type": "one-off", - "CreatedAt": "2021-12-24T16:09:45.926075752+01:00", - "ExpiresAt": "2022-01-23T16:09:45.926075752+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:12:45.763424077+01:00" - }, - "EB51E9EB-A11F-4F6E-8E49-C982891B405A": { - "Id": "1769568301", - "Key": "EB51E9EB-A11F-4F6E-8E49-C982891B405A", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-12-24T16:09:45.926073628+01:00", - "ExpiresAt": "2022-01-23T16:09:45.926073628+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:13:06.236748538+01:00" - } - }, - "Network": { - "Id": "a443c07a-5765-4a78-97fc-390d9c1d0e49", - "Net": { - "IP": "100.64.0.0", - "Mask": "/8AAAA==" - }, - "Dns": "" - }, - "Peers": { - "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=": { - "Key": "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=", - "SetupKey": "EB51E9EB-A11F-4F6E-8E49-C982891B405A", - "IP": "100.64.0.2", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:13:11.244342541+01:00", - "Connected": false - } - }, - "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=": { 
- "Key": "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=", - "SetupKey": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD", - "IP": "100.64.0.1", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:12:49.089339333+01:00", - "Connected": false - } - } - } - }, - "google-oauth2|103201118415301331038": { - "Id": "google-oauth2|103201118415301331038", - "SetupKeys": { - "5AFB60DB-61F2-4251-8E11-494847EE88E9": { - "Id": "2485964613", - "Key": "5AFB60DB-61F2-4251-8E11-494847EE88E9", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-12-24T16:10:02.238476+01:00", - "ExpiresAt": "2022-01-23T16:10:02.238476+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:12:05.994307717+01:00" - }, - "A72E4DC2-00DE-4542-8A24-62945438104E": { - "Id": "3504804807", - "Key": "A72E4DC2-00DE-4542-8A24-62945438104E", - "Name": "One-off key", - "Type": "one-off", - "CreatedAt": "2021-12-24T16:10:02.238478209+01:00", - "ExpiresAt": "2022-01-23T16:10:02.238478209+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:11:27.015741738+01:00" - } - }, - "Network": { - "Id": "b6d0b152-364e-40c1-a8a1-fa7bcac2267f", - "Net": { - "IP": "100.64.0.0", - "Mask": "/8AAAA==" - }, - "Dns": "" - }, - "Peers": { - "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=": { - "Key": "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=", - "SetupKey": "5AFB60DB-61F2-4251-8E11-494847EE88E9", - "IP": "100.64.0.2", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:12:05.994305438+01:00", - "Connected": false - } - }, - "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=": { - "Key": "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=", - 
"SetupKey": "A72E4DC2-00DE-4542-8A24-62945438104E", - "IP": "100.64.0.1", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:11:27.015739803+01:00", - "Connected": false - } - } - } - } - } -} \ No newline at end of file diff --git a/management/server/testdata/storev1.sqlite b/management/server/testdata/storev1.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..9a376698e4d226fc08fa68c12fb9bb4cf50375cd GIT binary patch literal 163840 zcmeI5U2GfKb;miPBucg=uCv)poJ29DU9U~Zv^o5y17WKV+N-Q3c_lkpV-XC8nW4`*J9@$Z(o!v!l_ zZtRO4rs=98wM402XSyyL>Mm1sv8yvplpBL{FD`FZ7VlRQtLrP3_Y&3PtEnb7))Upi zl~u1*oyMxkr(4r8zFg+Mzi3Fc78_h(aBAG?)BVMzyOp7BC9Y9#s|kbc8`r&Nol`|> zGdF8T(sx+HhT*2M{f?|jhO8=LLo!Izd^hAayI!qpj6OHS+JP(VrbtFmt|jS)iI%$H z6~4E*y0*CaF!4_1q20UW_2rHAt^1pctLyg@)$LtN7rNMyL_-&w)x_RxFITn_*3fh!OM^`&--(a&VtJT}48|!6gVSHBO~v>;&mE3=7G(B7H_b=w zDhAn%bw1s1BE539QCO+eQ@?yQ`E22QjOTgof|4 zJ;pe#CTwq{Go#xP0||1+HLNz7rMRQ&M$iDQ#HK2$dPr=SX$otFlIp7F8bQ@6)oDqF z8N@i zXg<}L7V3v1(9zUJw{C=HrZq$-zbz@C-rc2pZ6Vvg&C{)=^X7&+5Wx5y#dRXgK1 zW2bM{fn16l+JRVC8=j&1nX+*Z(yy1(WU7@Ba)+mIbJh&)kEkbia}!6oIlFcWqcn>q z_r_!VTeIBZCC>+W(>)kE{lr+mQ}Qzd!Pa(AepZj`d-mpNjGvw5KIScF<;(e>#=>sP z&p3te3Qm+j$QcspY%`~Eht@{jj=3ai^rUKRj~1jsvn5lfg&8fo$YXdhaL~f4Iv_-M zb*ah3mfV)9ZLjnI(>;Y5`NW>Xma=2x{GHihAv4f~;C~*kMPqz+hCAdut8yc;*ihS& zthlG6{nq4WIg+*=xrt;&b_LbsNB6iGKQqJqkX#)(5Ld!Ij)bkpO%B_RyBI!hdF-_5 z_+$G_BoD%xs^+b%X$|M(bTqjwk=t+@4Y)%$xX2Ci@A9_nJAXHaj;`N#lgPvOCfV^P zU@2x4AO#q*(xh(4y5&h@!nyqCYkalB%~PxgB!WAt|ZVQlAV%M9W@HqU|UJGB>~t z2wGl!wz#`6CYL|~Q6>+fYT*X@Cvx)U$T+_|Gir>MEljzqpVJw@*$)Bg~O zKcqLjKmY_l00ck)1V8`;KmY_l00cnb^FZJ*Ir7@p)Y{b4#Di~7O!+MS3kP%zgO&ch*1o@kS(a_y+gd>VUqQeN2C^Ak?-?nMQ+2wR9<$P8TFe zsMYI@LERKeLaFFk56<%@-UDkl|7sd90E5C8!X009sH0T2KI5CDNQPT*-|G(Nla;PtO=ZEa0VZe5(1$mY|zlKHQAg8Ph% 
z#Y`o)vYbh!(<_x!E}zS!N{jTobs09Z{jZC(bE7HT-<@vNgPi2z`RVt)1*`q{GP)krmxRoj1NDKL5PAFzec0=!f z=Wsu2A(&ZPTeA^la;1E^kShq;6ClVh-d-xCSC&!*;dUmK%jAVrsUireTscQ0s#Hp) zN_mBCOWl@XBG4H9YyF~3LW!Q9Kl<)*WhpOI%9W+mq9EK(*RNE|7^8ilo_q=D*9uFVpaS<1b}}Vx}l$jv2kM zlv&EBm(r=FY^9XSEv1)Ji@8i8mCxiWD@&zZNmyDr&g=u-&KL8UL0TUsGI;QzS^^?M z#-v;w=GH}PM`e;NO;Gw$dxJP-f@5C8!X009sH0T2KI5C8!X z_#-AT9i8R;D|VyiseknjFHd?`IPCrZuKE4{e~$kozWYZE0bK_H5C8!X z009sH0T2KI5C8!X0D;c~ffIgdfAoZM(YfdtPK})Z3;KW; z2!H?xfB*=900@A<871&+Wb`WAVh^rdroT&j>CzOvBJ}@aD+#)AQo~DYV=1LyC2Lgx72Qfezlu^zx$r^ zcSFsy553af2YNm|{FJ))@OZXt{dTzdlsdfx^C|VTkf0wHxBksr?^JfckB-yNdDC;g z=8ulg%{_}mhbZB4j}i*Ik8162Jr;J_-TJq_ySMb7n&0TwALch6?#(~i(wBrs5ANT6 zoGA+{PfntQ(^PkoXWM-xWd7hv=uv_H5WD%^0{cPc=TO1Z4@a)D)sxJF{nmc@W_hJF zFV~b!d3ANEq`ZA+ZU4^3clW;~6t&XU&b_yra!z`xFP&|r+~V)F0Gaxe_Kv3AJki0-6wZ8=jXS& z+Wt2W*0P;4X&j8SOOmf0=)3@*c zFGS)>d?Eg?@xM8v!4AU$0T2KI5C8!X009sH0T2KI5CDPSpFldwxYJzC95oA^=E_^t zEO45OMI$$(Z1OZ$H?aTz{nd|>AOHd&00JNY0w4eaAOHd&00Ms^1kTLc|3_YV@qHw{S00@8p2!H?xfB*=900@8p2!OzgCV>0@ zFPbpw2LTWO0T2KI5C8!X009sH0T2Lzmrnru|Cg@=^Z*1v00ck)1V8`;KmY_l00ck) z1YR@&Jpccq38Q`x009sH0T2KI5C8!X009sH0T6ik1aSZVxoLM`@bc@{}g)B8_WLt}_c+HKBCdOq1)^tBR>J)?!wJf%Krb7}UQx<(=8?YP8l; z>$|4ryjm@7w$!Gq1k~x-LRj;%u6NDCrsXA8Dz_IO+`XTelREO;^{T8Gtm#;m#0GZlby(P(tn>5z=F}_a`WPzR zV@`J`$N8%~Cwmra7u{tvVv;^LtZotrS866Ce-MjV7WbpJqcn0SwnR4+mh1Vmg+`V>nYEaJy}zgHdDx_+6u7BZObjcD?!{13cM|; z=A$+BE9!=1bj=KM#X4i5xf>LytBS(vjJ_asuX z=RHjKW=(a+mh_#VCFyj^6f>d*+Yhv92ciKY=D1Gilvrnnk%tU*i@cm&wKIM*-1%l5 z$fd}k9f)qq5a~gMOIn(W!tC&Vls>b$c(Hk^dGId%@(MpLth8Nfdt%a%s;%--$noMlT zZJFBkN)IsIQ<#xY>^W>HJ2uYWnH{#|2{a-2pT}#_7@wWt4mr=N-0&(k)V3rm?&)a1 zHMv=iq-{rTB3Y4LK{ff&Jub%2%y2&>S4R%Sm2i(EVe4^|!?xqDUyobSI&C`s*gg}< zgRrKmd5d9Mm^nEeO>Rr%Hk?KS?$8Y`a)bQ4ye<3A-_4<;>-XIx^6Mcobhn#gtN@}&#C&LiYT2_;2%S(aG4R8a3mRFxG z?yh#pB~U<=$%Cj`xPktOoV+!;h(7Ruk5|L z`+oSP%Q1dK;10{)#5Fyrwc()QJII~pDZ2G%&h<~|ej#^aSn<)^fcJ~^go~#i&ct{@ z;6A!yP4<5I{**-60rU%mU4cI+Z>|M)$_=L1HMwJMi-Xw?7;k#ofLsa34FsmY?H@ex 
zLAEPt^&Q8L=>2BaZ#2EzXxBFE1&F!z^A*>(tqu(O=-trr*T?y93&Sk4pRFkLr_FAE zEyfowafg$hanXe)!(2CdTXZ|&a=#JvN!oG2PVO<5XIH)w<1by}KDLd;ma*Sw!baf~ zPXw=*_?>a}==fpxv`2}%jCDk-qJW~QNMO}FzcUwh^ItvcDx_1rU?0E! z!KE0#c!fKh@rN(y4S!0t)C6;1(FXZ_n4?7&U6md7N z=9!meS1NZabS-yzacg;TrIIXud7OXe$}kt?)N31q5Ip;;@}(I64O)y|3|o%bMSrfl z?Oz&VGGXa}y--Gv`+0idi`K&Q*{*LaB;~fi4=@t5lH;7vX)VOoKiRl$UStsKJFLFz zD)x9SI%J%QW%Vz7Mb<}0+6rEjbxS&$Ox}F$s6~xSMYMHB4So6ipWBO??N@ulMc?|v zPWV#4&iW*6jX8cWd{OsQH!7aKanXt`-+$dwu3e%1{)Mn&owiP)*5K9C>vU7XykO|~ z2X{GXUZ?mBE}q(be-77v>U&0m@14A``#JB#$9b_l)R_?Z0k^m4lF6X-S0WLv_y@Z= zG@#+IYv{S67K|K!T=#PxU5W8GZ*qrGdr9AW+0|VFdDmgDh38jazKO|Nclnp6C0Z7l z7ao&`<8Q?HJ4x;^83+PF#~2C)!Hlc@&L!+p10~(C5G+q_Fz70uM(1LDGRZx2ZX|}h zbGYIw;qitu!!{8#f&t5iKnEOu^7}`F(AawJ(goT%IwpMM^>Kb9IduAkpqP;%*yx*w z=VvX1hwoa3WAC@!?RWU>jWEJ+rit*lUA^mfhi>DOha-s?fAuPNxNP-CR_Jn6i*93( zr`AI!uZQ~*%5$~pE%bxAw2?5cA#~jZVt?>X;`x8~3I>iK00JNY0w4eaAOHd&00JNY z0w8d93E=nt&aQ4TG7ta(5C8!X009sH0T2KI5C8!Xa0%f4ze@r~5C8!X009sH0T2KI z5C8!X009s{{QUi79#@z5C8!X009sH0T2KI5C8!X00EZ( zp8t1A;0OXB00JNY0w4eaAOHd&00JNY0%w;1?*E@%-C|@Q00JNY0w4eaAOHd&00JNY z0wCZL!2aJQfg=cj00@8p2!H?xfB*=900@8p2%KF4=JWqt{C5%hhZhKd00@8p2!H?x zfB*=900@8p2!O!ZBM{}Hktuxt|LiFiBLV>s009sH0T2KI5C8!X009sHfnEaG|MwEX zIS7CN2!H?xfB*=900@8p2!H?xoIL{g{{Pw2DMkbWAOHd&00JNY0w4eaAOHd&00O-P zu>bERf^!f60T2KI5C8!X009sH0T2KI5IB1T@cjST( Date: Fri, 4 Oct 2024 17:17:01 +0300 Subject: [PATCH 18/37] [management] Refactor User JWT group sync (#2690) * Refactor GetAccountIDByUserOrAccountID Signed-off-by: bcmmbaga * sync user jwt group changes Signed-off-by: bcmmbaga * propagate jwt group changes to peers Signed-off-by: bcmmbaga * fix no jwt groups synced Signed-off-by: bcmmbaga * fix tests and lint Signed-off-by: bcmmbaga * Move the account peer update outside the transaction Signed-off-by: bcmmbaga * move updateUserPeersInGroups to account manager Signed-off-by: bcmmbaga * move event store outside of transaction Signed-off-by: bcmmbaga * get user with update lock Signed-off-by: bcmmbaga * Run jwt 
sync in transaction Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 284 ++++++++++++------ management/server/account_test.go | 169 +++++++---- management/server/mock_server/account_mock.go | 12 +- management/server/sql_store.go | 120 +++++++- management/server/sql_store_test.go | 30 ++ management/server/store.go | 5 +- management/server/user.go | 72 ++++- management/server/user_test.go | 5 +- 8 files changed, 520 insertions(+), 177 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index d5e8c8cf8..da3203852 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -76,7 +76,7 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) - GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) + GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) @@ -478,12 +478,12 @@ func (a *Account) GetPeerNetworkMap( } nm := &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, RoutesFirewallRules: routesFirewallRules, } @@ -843,55 +843,54 @@ func (a *Account) 
GetPeer(peerID string) *nbpeer.Peer { return a.Peers[peerID] } -// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. -// Returns true if there are changes in the JWT group membership. -func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { - user, ok := a.Users[userID] - if !ok { - return false - } - +// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. +// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, +// newly groups to create and an error if any occurred. +func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { existedGroupsByName := make(map[string]*nbgroup.Group) - for _, group := range a.Groups { + for _, group := range groups { existedGroupsByName[group.Name] = group } - newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) - groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) + newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) + + groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { - return false + return false, nil, nil, nil } + newGroupsToCreate := make([]*nbgroup.Group, 0) + var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { group = &nbgroup.Group{ - ID: xid.New().String(), - Name: name, - Issued: nbgroup.GroupIssuedJWT, + ID: xid.New().String(), + AccountID: user.AccountID, + Name: name, + Issued: nbgroup.GroupIssuedJWT, } - a.Groups[group.ID] = group + newGroupsToCreate = append(newGroupsToCreate, group) } if group.Issued == nbgroup.GroupIssuedJWT { - 
newAutoGroups = append(newAutoGroups, group.ID) + newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } } for name, id := range jwtGroupsMap { if !slices.Contains(groupsToRemove, name) { - newAutoGroups = append(newAutoGroups, id) + newUserAutoGroups = append(newUserAutoGroups, id) continue } modified = true } - user.AutoGroups = newAutoGroups - return modified + return modified, newUserAutoGroups, newGroupsToCreate, nil } // UserGroupsAddToPeers adds groups to all peers of user @@ -1262,37 +1261,31 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. -// If an accountID is provided, it checks if the account exists and returns it. -// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// GetAccountIDByUserID retrieves the account ID based on the userID provided. +// If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. // Returns the account ID or an error if none is found or created. 
-func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { - if accountID != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) - } - return accountID, nil +func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) { + if userID == "" { + return "", status.Errorf(status.NotFound, "no valid userID provided") } - if userID != "" { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) - if err != nil { - return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) - } + accountID, err := am.Store.GetAccountIDByUserID(userID) + if err != nil { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + if err != nil { + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { - return "", err + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + return "", err + } + return account.Id, nil } - - return account.Id, nil + return "", err } - - return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") + return accountID, nil } func isNil(i idp.Manager) bool { @@ -1796,6 +1789,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } + if user.AccountID != accountID { + return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID) + } + if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(ctx, 
accountID, user.Id) if err != nil { @@ -1803,7 +1800,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai } } - if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { + if err = am.syncJWTGroups(ctx, accountID, claims); err != nil { return "", "", err } @@ -1812,7 +1809,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. -func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err @@ -1823,69 +1820,136 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set") return nil } - // TODO: Remove GetAccount after refactoring account peer's update - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer func() { + if unlockPeer != nil { + unlockPeer() + } + }() - // Update the account if group membership changes - if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := 
difference(oldGroups, user.AutoGroups) - - if settings.GroupsPropagationEnabled { - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) - account.Network.IncSerial() + var addNewGroups []string + var removeOldGroups []string + var hasChanges bool + var user *User + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user: %w", err) } - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) + } + + changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames) + if err != nil { + return fmt.Errorf("error getting JWT groups changes: %w", err) + } + + hasChanges = changed + // skip update if no changes + if !changed { return nil } + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + addNewGroups = difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + + user.AutoGroups = updatedAutoGroups + if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + return fmt.Errorf("error saving user: %w", err) + } + // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) - } + groups, err = transaction.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) + } - for _, g := range 
addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user peers: %w", err) + } + + updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) + if err != nil { + return fmt.Errorf("error modifying user peers in groups: %w", err) + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return fmt.Errorf("error incrementing network serial: %w", err) } } + unlockPeer() + unlockPeer = nil - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) + return nil + }) + if err != nil { + return err + } + + if !hasChanges { + return nil + } + + for _, g := range addNewGroups { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } + am.StoreEvent(ctx, user.Id, user.Id, accountID, 
activity.GroupAddedToUser, meta) } } + for _, g := range removeOldGroups { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) + } + } + + if settings.GroupsPropagationEnabled { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + return nil } @@ -1916,7 +1980,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + if claims.AccountId != "" { + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) + if err != nil { + return "", err + } + if !exists { + return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) + } + return claims.AccountId, nil + } + return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } else if claims.AccountId != "" { userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err != nil { @@ -2229,7 +2303,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac routes := make(map[route.ID]*route.Route) setupKeys := map[string]*SetupKey{} nameServersGroups := 
make(map[string]*nbdns.NameServerGroup) - users[userID] = NewOwnerUser(userID) + + owner := NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } @@ -2297,12 +2375,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID + allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + for _, group := range allGroups { + allGroupsMap[group.ID] = group + } + for _, id := range autoGroups { - if group, ok := allGroups[id]; ok { + if group, ok := allGroupsMap[id]; ok { if group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { @@ -2310,5 +2393,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([ } } } + return newAutoGroups, jwtAutoGroups } diff --git a/management/server/account_test.go b/management/server/account_test.go index 198775bc3..c417e4bc8 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := 
manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") initAccount, err := manager.Store.GetAccount(context.Background(), accountID) @@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - initAccount := newAccountWithId(context.Background(), "", userId, domain) + _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID := initAccount.Id - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount, err = manager.Store.GetAccount(context.Background(), accountID) + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") claims := jwtclaims.AuthorizationClaims{ @@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } } -func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { +func TestAccountManager_GetAccountByUserID(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - accountID, err := 
manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "") if err != nil { t.Fatal(err) } @@ -903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { return } - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) - } + exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + assert.NoError(t, err) + assert.True(t, exists, "expected to get existing account after creation using userid") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), "", "") if err == nil { - t.Errorf("expected an error when user and account IDs are empty") + t.Errorf("expected an error when user ID is empty") } } @@ -1669,7 +1667,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) @@ -1684,7 +1682,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := 
wgtypes.GenerateKey() @@ -1696,7 +1694,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1742,7 +1740,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1770,7 +1768,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. 
}, } - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1790,7 +1788,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1802,7 +1800,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1850,7 +1848,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ @@ -1861,9 +1859,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - accountID, err = 
manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - require.NoError(t, err, "unable to get account by ID") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") @@ -2199,8 +2194,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } func TestAccount_SetJWTGroups(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + // create a new account account := &Account{ + Id: "accountID", Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2211,62 +2210,120 @@ func TestAccount_SetJWTGroups(t *testing.T) { Groups: map[string]*group.Group{ "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, - Settings: &Settings{GroupsPropagationEnabled: true}, + Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, Users: map[string]*User{ - "user1": {Id: "user1"}, - "user2": {Id: "user2"}, + "user1": {Id: "user1", AccountID: "accountID"}, + "user2": {Id: "user2", AccountID: "accountID"}, }, } + assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") + t.Run("empty jwt groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.False(t, updated, "account should not be updated") - assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable 
to get user") + assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) t.Run("jwt match existing api group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} + assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"])) - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + 
assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 2, "new group should be added") - assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("existed group not update", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group2"}) - assert.False(t, updated, "account should not be updated") - assert.Len(t, account.Groups, 2, "groups count should not be changed") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("add new 
group", func(t *testing.T) { - updated := account.SetJWTGroups("user2", []string{"group1", "group3"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 3, "new group should be added") - assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user2", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") + assert.NoError(t, err) + assert.Len(t, groups, 3, "new group3 should be added") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "new group should be added") }) t.Run("remove all JWT groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain") - assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") + assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") }) } diff --git 
a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b399be822..b6283a7e6 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,7 +27,7 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -194,14 +194,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { - if am.GetAccountIDByUserOrAccountIdFunc != nil { - return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { + if am.GetAccountIDByUserIdFunc != nil { + return am.GetAccountIDByUserIdFunc(ctx, userId, domain) } return "", status.Errorf( 
codes.Unimplemented, - "method GetAccountIDByUserOrAccountID is not implemented", + "method GetAccountIDByUserID is not implemented", ) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index cce748a0f..9e1ab27dc 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "runtime/debug" + "slices" "strings" "sync" "time" @@ -378,15 +379,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { Create(&usersToSave).Error } -// SaveGroups saves the given list of groups to the database. -// It updates existing groups if a conflict occurs. -func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { - groupsToSave := make([]nbgroup.Group, 0, len(groups)) - for _, group := range groups { - group.AccountID = accountID - groupsToSave = append(groupsToSave, *group) +// SaveUser saves the given user to the database. +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) } - return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error + return nil +} + +// SaveGroups saves the given list of groups to the database. 
+func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + if len(groups) == 0 { + return nil + } + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) + } + return nil } // DeleteHashedPAT2TokenIDIndex is noop in SqlStore @@ -1021,6 +1033,89 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// AddUserPeersToGroups adds the user's peers to specified groups in database. +func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + var userPeerIDs []string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). + Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) + if result.Error != nil { + return status.Errorf(status.Internal, "issue finding user peers") + } + + groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) + for _, gid := range groupIDs { + group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return err + } + + groupPeers := make(map[string]struct{}) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for _, pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = group.Peers[:0] + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } + + groupsToUpdate = append(groupsToUpdate, group) + } + + return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) +} + +// RemoveUserPeersFromGroups removes the user's peers from specified groups in database. 
+func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + var userPeerIDs []string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). + Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) + if result.Error != nil { + return status.Errorf(status.Internal, "issue finding user peers") + } + + groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) + for _, gid := range groupIDs { + group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return err + } + + if group.Name == "All" { + continue + } + + update := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if !slices.Contains(userPeerIDs, pid) { + update = append(update, pid) + } + } + + group.Peers = update + groupsToUpdate = append(groupsToUpdate, group) + } + + return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) +} + +// GetUserPeers retrieves peers for a user. +func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { + return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) +} + func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account") @@ -1127,6 +1222,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } +// SaveGroup saves a group to the store. 
+func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + } + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index dc07849d9..4eed09c69 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1185,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } + +func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + group := &nbgroup.Group{ + ID: "group-id", + AccountID: "account-id", + Name: "group-name", + Issued: "api", + Peers: nil, + } + err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error { + err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + t.Fatal("failed to save group") + return err + } + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + if err != nil { + t.Fatal("failed to get group") + return err + } + t.Logf("group: %v", group) + return nil + }) + assert.NoError(t, err) +} diff --git a/management/server/store.go b/management/server/store.go index 041c936ae..50bc6afdf 100644 --- a/management/server/store.go +++ 
b/management/server/store.go @@ -60,6 +60,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) SaveUsers(accountID string, users map[string]*User) error + SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error @@ -68,7 +69,8 @@ type Store interface { GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -82,6 +84,7 @@ type Store interface { AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) 
error SavePeerLocation(accountID string, peer *nbpeer.Peer) error diff --git a/management/server/user.go b/management/server/user.go index 6d01561c6..38a8ac0c4 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -8,14 +8,14 @@ import ( "time" "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) const ( @@ -1254,6 +1254,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil } +// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. 
+func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { + + if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { + return + } + + userPeerIDMap := make(map[string]struct{}, len(peers)) + for _, peer := range peers { + userPeerIDMap[peer.ID] = struct{}{} + } + + for _, gid := range groupsToAdd { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + addUserPeersToGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + for _, gid := range groupsToRemove { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + removeUserPeersFromGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + return groupsToUpdate, nil +} + +// addUserPeersToGroup adds the user's peers to the group. +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + groupPeers := make(map[string]struct{}, len(group.Peers)) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = make([]string, 0, len(groupPeers)) + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } +} + +// removeUserPeersFromGroup removes user's peers from the group. 
+func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + // skip removing peers from group All + if group.Name == "All" { + return + } + + updatedPeers := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if _, found := userPeerIDs[pid]; !found { + updatedPeers = append(updatedPeers, pid) + } + } + + group.Peers = updatedPeers +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index ec0a10695..1a5704551 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -813,10 +813,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") - assert.NoError(t, err) - - acc, err := am.Store.GetAccount(context.Background(), accID) + acc, err := am.Store.GetAccount(context.Background(), account.Id) assert.NoError(t, err) for _, id := range tc.expectedDeleted { From 8bf729c7b4e8d3f7a7a88c9dc2ac60b63c803a8c Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 4 Oct 2024 18:09:40 +0300 Subject: [PATCH 19/37] [management] Add AccountExists to AccountManager (#2694) * Add AccountExists method to account manager interface Signed-off-by: bcmmbaga * remove unused code Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 6 ++ management/server/mock_server/account_mock.go | 13 ++- management/server/sql_store.go | 79 ------------------- 3 files changed, 17 insertions(+), 81 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index da3203852..a9781b385 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -76,6 +76,7 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID 
string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error @@ -1261,6 +1262,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } +// AccountExists checks if an account exists. +func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) +} + // GetAccountIDByUserID retrieves the account ID based on the userID provided. // If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. 
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b6283a7e6..ec29222a4 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,6 +27,7 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) + AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) @@ -58,7 +59,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) 
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -194,6 +195,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } +// AccountExists mock implementation of AccountExists from server.AccountManager interface +func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + if am.GetAccountIDByUserIdFunc != nil { + return am.AccountExistsFunc(ctx, accountID) + } + return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented") +} + // GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { if am.GetAccountIDByUserIdFunc != nil { @@ -444,7 +453,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID // CreateRoute mock implementation of CreateRoute from server.AccountManager interface func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, 
netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 9e1ab27dc..d056015d8 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,7 +10,6 @@ import ( "path/filepath" "runtime" "runtime/debug" - "slices" "strings" "sync" "time" @@ -1033,84 +1032,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } -// AddUserPeersToGroups adds the user's peers to specified groups in database. -func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { - if len(groupIDs) == 0 { - return nil - } - - var userPeerIDs []string - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). - Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) - if result.Error != nil { - return status.Errorf(status.Internal, "issue finding user peers") - } - - groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) - for _, gid := range groupIDs { - group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) - if err != nil { - return err - } - - groupPeers := make(map[string]struct{}) - for _, pid := range group.Peers { - groupPeers[pid] = struct{}{} - } - - for _, pid := range userPeerIDs { - groupPeers[pid] = struct{}{} - } - - group.Peers = group.Peers[:0] - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - - groupsToUpdate = append(groupsToUpdate, group) - } - - return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) -} - -// RemoveUserPeersFromGroups removes the user's peers from specified groups in database. 
-func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { - if len(groupIDs) == 0 { - return nil - } - - var userPeerIDs []string - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). - Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) - if result.Error != nil { - return status.Errorf(status.Internal, "issue finding user peers") - } - - groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) - for _, gid := range groupIDs { - group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) - if err != nil { - return err - } - - if group.Name == "All" { - continue - } - - update := make([]string, 0, len(group.Peers)) - for _, pid := range group.Peers { - if !slices.Contains(userPeerIDs, pid) { - update = append(update, pid) - } - } - - group.Peers = update - groupsToUpdate = append(groupsToUpdate, group) - } - - return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) -} - // GetUserPeers retrieves peers for a user. 
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) From 5897a48e299d5553b6b375d2bc2b4df3e2dc24f1 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 4 Oct 2024 18:55:25 +0300 Subject: [PATCH 20/37] fix wrong reference (#2695) Signed-off-by: bcmmbaga --- management/server/mock_server/account_mock.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ec29222a4..74557e227 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -197,7 +197,7 @@ func (am *MockAccountManager) CreateSetupKey( // AccountExists mock implementation of AccountExists from server.AccountManager interface func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { - if am.GetAccountIDByUserIdFunc != nil { + if am.AccountExistsFunc != nil { return am.AccountExistsFunc(ctx, accountID) } return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented") From f603cd92027f76c91b7d14ebb64ae6f3aa328639 Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Fri, 4 Oct 2024 11:15:16 -0600 Subject: [PATCH 21/37] [client] Check wginterface instead of engine ctx (#2676) Moving code to ensure wgInterface is gone right after context is cancelled/stop in the off chance that on next retry the backoff operation is permanently cancelled and interface is abandoned without destroying. 
--- client/internal/connect.go | 15 +++++++++------ client/internal/engine.go | 26 ++++++++++++++------------ 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index c77f95603..74dc1f1b5 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -269,12 +269,6 @@ func (c *ConnectClient) run( checks := loginResp.GetChecks() c.engineMutex.Lock() - if c.engine != nil && c.engine.ctx.Err() != nil { - log.Info("Stopping Netbird Engine") - if err := c.engine.Stop(); err != nil { - log.Errorf("Failed to stop engine: %v", err) - } - } c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engineMutex.Unlock() @@ -294,6 +288,15 @@ func (c *ConnectClient) run( } <-engineCtx.Done() + c.engineMutex.Lock() + if c.engine != nil && c.engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) + if err := c.engine.Stop(); err != nil { + log.Errorf("Failed to stop engine: %v", err) + } + c.engine = nil + } + c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() diff --git a/client/internal/engine.go b/client/internal/engine.go index c51901a22..eac8ec098 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -251,6 +251,13 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") + // stop/restore DNS first so dbus and friends don't complain because of a missing interface + e.stopDNSServer() + + if e.routeManager != nil { + e.routeManager.Stop() + } + err := e.removeAllPeers() if err != nil { return fmt.Errorf("failed to remove all peers: %s", err) @@ -1116,18 +1123,12 @@ func (e *Engine) close() { } } - // stop/restore DNS first so dbus and friends don't complain because of a missing interface - e.stopDNSServer() - - if e.routeManager != nil { - e.routeManager.Stop() - } - 
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) } + e.wgInterface = nil } if !isNil(e.sshServer) { @@ -1395,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() { } // Set a new timer to debounce rapid network changes - debounceTimer = time.AfterFunc(1*time.Second, func() { + debounceTimer = time.AfterFunc(2*time.Second, func() { // This function is called after the debounce period mu.Lock() defer mu.Unlock() @@ -1426,6 +1427,11 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { } func (e *Engine) stopDNSServer() { + if e.dnsServer == nil { + return + } + e.dnsServer.Stop() + e.dnsServer = nil err := fmt.Errorf("DNS server stopped") nsGroupStates := e.statusRecorder.GetDNSStates() for i := range nsGroupStates { @@ -1433,10 +1439,6 @@ func (e *Engine) stopDNSServer() { nsGroupStates[i].Error = err } e.statusRecorder.UpdateDNSStates(nsGroupStates) - if e.dnsServer != nil { - e.dnsServer.Stop() - e.dnsServer = nil - } } // isChecksEqual checks if two slices of checks are equal. 
From dbec24b52080bd14739e9b0dd8c950e4b0edf0cb Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sun, 6 Oct 2024 17:01:13 +0200 Subject: [PATCH 22/37] [management] Remove admin check on getAccountByID (#2699) --- management/server/account.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index a9781b385..6ee0015f8 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -20,6 +20,11 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -36,10 +41,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" ) const ( @@ -1764,7 +1765,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s return nil, err } - if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + if user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } From 2c1f5e46d5928a21458749251c4ee5eb96575239 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 7 Oct 2024 19:06:26 +0300 Subject: [PATCH 23/37] [management] Validate peer ownership during login (#2704) * check peer ownership in login Signed-off-by: bcmmbaga * update error message Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/peer.go | 5 +++++ 1 file changed, 5 
insertions(+) diff --git a/management/server/peer.go b/management/server/peer.go index da9586734..a7d4f3b06 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -693,6 +693,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) updateRemotePeers := false if login.UserID != "" { + if peer.UserID != login.UserID { + log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) + return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user") + } + changed, err := am.handleUserPeer(ctx, peer, settings) if err != nil { return nil, nil, nil, err From 44e81073832c2d1167ef6ec284c8bb9f4c5bc3d8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:21:11 +0200 Subject: [PATCH 24/37] [client] Limit P2P attempts and restart on specific events (#2657) --- client/internal/peer/conn.go | 60 ++++--- client/internal/peer/conn_monitor.go | 212 +++++++++++++++++++++++++ client/internal/peer/stdnet.go | 4 +- client/internal/peer/stdnet_android.go | 4 +- client/internal/peer/worker_ice.go | 63 ++++---- 5 files changed, 285 insertions(+), 58 deletions(-) create mode 100644 client/internal/peer/conn_monitor.go diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ad84bd700..0d4ad2396 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -32,6 +32,8 @@ const ( connPriorityRelay ConnPriority = 1 connPriorityICETurn ConnPriority = 1 connPriorityICEP2P ConnPriority = 2 + + reconnectMaxElapsedTime = 30 * time.Minute ) type WgConfig struct { @@ -83,6 +85,7 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager allowedIPsIP string handshaker *Handshaker @@ -108,6 +111,8 @@ type Conn struct { // for reconnection operations iCEDisconnected chan bool 
relayDisconnected chan bool + connMonitor *ConnMonitor + reconnectCh <-chan struct{} } // NewConn creates a new not opened Conn to the remote peer. @@ -123,21 +128,31 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu connLog := log.WithField("peer", config.Key) var conn = &Conn{ - log: connLog, - ctx: ctx, - ctxCancel: ctxCancel, - config: config, - statusRecorder: statusRecorder, - wgProxyFactory: wgProxyFactory, - signaler: signaler, - relayManager: relayManager, - allowedIPsIP: allowedIPsIP.String(), - statusRelay: NewAtomicConnStatus(), - statusICE: NewAtomicConnStatus(), + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + wgProxyFactory: wgProxyFactory, + signaler: signaler, + iFaceDiscover: iFaceDiscover, + relayManager: relayManager, + allowedIPsIP: allowedIPsIP.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), + iCEDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1), } + conn.connMonitor, conn.reconnectCh = NewConnMonitor( + signaler, + iFaceDiscover, + config, + conn.relayDisconnected, + conn.iCEDisconnected, + ) + rFns := WorkerRelayCallbacks{ OnConnReady: conn.relayConnectionIsReady, OnDisconnected: conn.onWorkerRelayStateDisconnected, @@ -200,6 +215,8 @@ func (conn *Conn) startHandshakeAndReconnect() { conn.log.Errorf("failed to send initial offer: %v", err) } + go conn.connMonitor.Start(conn.ctx) + if conn.workerRelay.IsController() { conn.reconnectLoopWithRetry() } else { @@ -309,12 +326,14 @@ func (conn *Conn) reconnectLoopWithRetry() { // With it, we can decrease to send necessary offer select { case <-conn.ctx.Done(): + return case <-time.After(3 * time.Second): } ticker := conn.prepareExponentTicker() defer ticker.Stop() time.Sleep(1 * time.Second) + for { select { case t := <-ticker.C: @@ -342,20 +361,11 @@ func (conn *Conn) reconnectLoopWithRetry() { if err != nil { conn.log.Errorf("failed to do handshake: 
%v", err) } - case changed := <-conn.relayDisconnected: - if !changed { - continue - } - conn.log.Debugf("Relay state changed, reset reconnect timer") - ticker.Stop() - ticker = conn.prepareExponentTicker() - case changed := <-conn.iCEDisconnected: - if !changed { - continue - } - conn.log.Debugf("ICE state changed, reset reconnect timer") + + case <-conn.reconnectCh: ticker.Stop() ticker = conn.prepareExponentTicker() + case <-conn.ctx.Done(): conn.log.Debugf("context is done, stop reconnect loop") return @@ -366,10 +376,10 @@ func (conn *Conn) reconnectLoopWithRetry() { func (conn *Conn) prepareExponentTicker() *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 0.01, + RandomizationFactor: 0.1, Multiplier: 2, MaxInterval: conn.config.Timeout, - MaxElapsedTime: 0, + MaxElapsedTime: reconnectMaxElapsedTime, Stop: backoff.Stop, Clock: backoff.SystemClock, }, conn.ctx) diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go new file mode 100644 index 000000000..75722c990 --- /dev/null +++ b/client/internal/peer/conn_monitor.go @@ -0,0 +1,212 @@ +package peer + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/pion/ice/v3" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +const ( + signalerMonitorPeriod = 5 * time.Second + candidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second +) + +type ConnMonitor struct { + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + config ConnConfig + relayDisconnected chan bool + iCEDisconnected chan bool + reconnectCh chan struct{} + currentCandidates []ice.Candidate + candidatesMu sync.Mutex +} + +func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) { + reconnectCh := make(chan 
struct{}, 1) + cm := &ConnMonitor{ + signaler: signaler, + iFaceDiscover: iFaceDiscover, + config: config, + relayDisconnected: relayDisconnected, + iCEDisconnected: iCEDisconnected, + reconnectCh: reconnectCh, + } + return cm, reconnectCh +} + +func (cm *ConnMonitor) Start(ctx context.Context) { + signalerReady := make(chan struct{}, 1) + go cm.monitorSignalerReady(ctx, signalerReady) + + localCandidatesChanged := make(chan struct{}, 1) + go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged) + + for { + select { + case changed := <-cm.relayDisconnected: + if !changed { + continue + } + log.Debugf("Relay state changed, triggering reconnect") + cm.triggerReconnect() + + case changed := <-cm.iCEDisconnected: + if !changed { + continue + } + log.Debugf("ICE state changed, triggering reconnect") + cm.triggerReconnect() + + case <-signalerReady: + log.Debugf("Signaler became ready, triggering reconnect") + cm.triggerReconnect() + + case <-localCandidatesChanged: + log.Debugf("Local candidates changed, triggering reconnect") + cm.triggerReconnect() + + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) { + if cm.signaler == nil { + return + } + + ticker := time.NewTicker(signalerMonitorPeriod) + defer ticker.Stop() + + lastReady := true + for { + select { + case <-ticker.C: + currentReady := cm.signaler.Ready() + if !lastReady && currentReady { + select { + case signalerReady <- struct{}{}: + default: + } + } + lastReady = currentReady + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) { + ufrag, pwd, err := generateICECredentials() + if err != nil { + log.Warnf("Failed to generate ICE credentials: %v", err) + return + } + + ticker := time.NewTicker(candidatesMonitorPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := 
cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil { + log.Warnf("Failed to handle candidate tick: %v", err) + } + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error { + log.Debugf("Gathering ICE candidates") + + transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList) + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd) + if err != nil { + return fmt.Errorf("create ICE agent: %w", err) + } + defer func() { + if err := agent.Close(); err != nil { + log.Warnf("Failed to close ICE agent: %v", err) + } + }() + + gatherDone := make(chan struct{}) + err = agent.OnCandidate(func(c ice.Candidate) { + log.Tracef("Got candidate: %v", c) + if c == nil { + close(gatherDone) + } + }) + if err != nil { + return fmt.Errorf("set ICE candidate handler: %w", err) + } + + if err := agent.GatherCandidates(); err != nil { + return fmt.Errorf("gather ICE candidates: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) + defer cancel() + + select { + case <-ctx.Done(): + return fmt.Errorf("wait for gathering: %w", ctx.Err()) + case <-gatherDone: + } + + candidates, err := agent.GetLocalCandidates() + if err != nil { + return fmt.Errorf("get local candidates: %w", err) + } + log.Tracef("Got candidates: %v", candidates) + + if changed := cm.updateCandidates(candidates); changed { + select { + case localCandidatesChanged <- struct{}{}: + default: + } + } + + return nil +} + +func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool { + cm.candidatesMu.Lock() + defer cm.candidatesMu.Unlock() + + if len(cm.currentCandidates) != len(newCandidates) { + cm.currentCandidates = newCandidates + return true + } + + for i, candidate := range cm.currentCandidates 
{ + if candidate.Address() != newCandidates[i].Address() { + cm.currentCandidates = newCandidates + return true + } + } + + return false +} + +func (cm *ConnMonitor) triggerReconnect() { + select { + case cm.reconnectCh <- struct{}{}: + default: + } +} diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go index ae31ebbf0..96d211dbc 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/stdnet.go @@ -6,6 +6,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" ) -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList) +func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ifaceBlacklist) } diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go index b411405bb..a39a03b1c 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/stdnet_android.go @@ -2,6 +2,6 @@ package peer import "github.com/netbirdio/netbird/client/internal/stdnet" -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) +func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index c4e9d1950..c86c1858f 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -233,41 +233,16 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { - transportNet, err := w.newStdNet() + transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) if err != nil { w.log.Errorf("failed to create pion's stdnet: %s", err) } - iceKeepAlive := 
iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI), - CandidateTypes: relaySupport, - InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList), - UDPMux: w.config.ICEConfig.UDPMux, - UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx, - NAT1To1IPs: w.config.ICEConfig.NATExternalIPs, - Net: transportNet, - FailedTimeout: &failedTimeout, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - LocalUfrag: w.localUfrag, - LocalPwd: w.localPwd, - } - - if w.config.ICEConfig.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - w.sentExtraSrflx = false - agent, err := ice.NewAgent(agentConfig) + + agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) if err != nil { - return nil, err + return nil, fmt.Errorf("create agent: %w", err) } err = agent.OnCandidate(w.onICECandidate) @@ -390,6 +365,36 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA } } +func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI), + CandidateTypes: candidateTypes, + InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList), + UDPMux: 
config.ICEConfig.UDPMux, + UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx, + NAT1To1IPs: config.ICEConfig.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: ufrag, + LocalPwd: pwd, + } + + if config.ICEConfig.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + return ice.NewAgent(agentConfig) +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ From d4ef84fe6e02e932fdfa24be2bcf2416634f832b Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:33:58 +0200 Subject: [PATCH 25/37] [management] Propagate error in store errors (#2709) --- management/server/sql_store.go | 46 +++++++++++++++---------------- management/server/status/error.go | 8 ++++-- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index d056015d8..67df29ef0 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -431,7 +431,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -444,7 +444,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, 
status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } if key.AccountID == "" { @@ -462,7 +462,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return token.ID, nil @@ -476,7 +476,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if token.UserID == "" { @@ -560,7 +560,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us @@ -623,7 +623,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if user.AccountID == "" { @@ -640,7 +640,7 @@ func (s 
*SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -658,7 +658,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -676,7 +676,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -689,7 +689,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -702,7 +702,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.NewSetupKeyNotFoundError() + return "", status.NewSetupKeyNotFoundError(result.Error) } if accountID == "" { @@ -723,7 +723,7 @@ func (s 
*SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } - return nil, status.Errorf(status.Internal, "issue getting IPs from store") + return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } // Convert the JSON strings to net.IP objects @@ -751,7 +751,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock return nil, status.Errorf(status.NotFound, "no peers found for the account") } log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error) } return labels, nil @@ -764,7 +764,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting network from store") + return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } return accountNetwork.Network, nil } @@ -776,7 +776,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - return nil, status.Errorf(status.Internal, "issue getting peer from store") + return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } return &peer, nil @@ -788,7 +788,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - return nil, status.Errorf(status.Internal, "issue getting 
settings from store") + return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) } return accountSettings.Settings, nil } @@ -956,7 +956,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } return &setupKey, nil } @@ -988,7 +988,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") } - return status.Errorf(status.Internal, "issue finding group 'All'") + return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -1000,7 +1000,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group 'All'") + return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } return nil @@ -1014,7 +1014,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group not found for account") } - return status.Errorf(status.Internal, "issue finding group") + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -1026,7 +1026,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group") + return 
status.Errorf(status.Internal, "issue updating group: %s", err) } return nil @@ -1039,7 +1039,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { - return status.Errorf(status.Internal, "issue adding peer to account") + return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } return nil @@ -1048,7 +1048,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count") + return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) } return nil } diff --git a/management/server/status/error.go b/management/server/status/error.go index d7fde35b9..29d185216 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error { } // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key -func NewSetupKeyNotFoundError() error { - return Errorf(NotFound, "setup key not found") +func NewSetupKeyNotFoundError(err error) error { + return Errorf(NotFound, "setup key not found: %s", err) +} + +func NewGetAccountFromStoreError(err error) error { + return Errorf(Internal, "issue getting account from store: %s", err) } // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store From b1eda43f4b748ee9940e9a2399f200b60e55f076 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Wed, 9 Oct 2024 13:56:25 +0100 Subject: [PATCH 26/37] Add Link to 
the Lawrence Systems video (#2711) --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index aa3ec41e5..270c9ad87 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) +### NetBird on Lawrence Systems (Video) +[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) ### Key features @@ -62,6 +64,7 @@ | | |
  • - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)
| |
  • - \[x] OpenWRT
| | | |
  • - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)
  • | |
    • - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)
    | | | | | |
    • - \[x] Docker
    | + ### Quickstart with NetBird Cloud - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) From b79c1d64cc10871e70c94ee5f47fdf2b4773f5a7 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 9 Oct 2024 20:17:25 +0200 Subject: [PATCH 27/37] [management] Make max open db conns configurable (#2713) --- .github/workflows/golang-test-linux.yml | 2 +- .github/workflows/release.yml | 2 +- management/server/sql_store.go | 11 +++++++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 524f35f6f..d6adcb27a 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -16,7 +16,7 @@ jobs: matrix: arch: [ '386','amd64' ] store: [ 'sqlite', 'postgres'] - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Install Go uses: actions/setup-go@v5 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7af6d3e4d..b2e2437e6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ concurrency: jobs: release: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: flags: "" steps: diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 67df29ef0..fe4dcafdb 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "runtime/debug" + "strconv" "strings" "sync" "time" @@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr if err != nil { return nil, err } - conns := runtime.NumCPU() - sql.SetMaxOpenConns(conns) // TODO: make it configurable + + conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS")) + if err != nil { + conns = runtime.NumCPU() + } + sql.SetMaxOpenConns(conns) + + log.Infof("Set max open db connections to 
%d", conns) if err := migrate(ctx, db); err != nil { return nil, fmt.Errorf("migrate: %w", err) From 6ce09bca1680c50012cd9467317290808b620808 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 9 Oct 2024 20:46:23 +0200 Subject: [PATCH 28/37] Add support to envsub go management configurations (#2708) This change allows users to reference environment variables using Go template format, like {{ .EnvName }} Moved the previous file test code to file_suite_test.go. --- management/cmd/management.go | 2 +- util/file.go | 53 +++++++ util/file_suite_test.go | 126 +++++++++++++++ util/file_test.go | 288 ++++++++++++++++++++++------------- 4 files changed, 360 insertions(+), 109 deletions(-) create mode 100644 util/file_suite_test.go diff --git a/management/cmd/management.go b/management/cmd/management.go index 78b1a8d63..719d1a78c 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { loadedConfig := &server.Config{} - _, err := util.ReadJson(mgmtConfigPath, loadedConfig) + _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) if err != nil { return nil, err } diff --git a/util/file.go b/util/file.go index 8355488c9..ecaecd222 100644 --- a/util/file.go +++ b/util/file.go @@ -1,11 +1,15 @@ package util import ( + "bytes" "context" "encoding/json" + "fmt" "io" "os" "path/filepath" + "strings" + "text/template" log "github.com/sirupsen/logrus" ) @@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) { return res, nil } +// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution +func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) { + envVars := getEnvMap() + + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + 
bs, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + t, err := template.New("").Parse(string(bs)) + if err != nil { + return nil, fmt.Errorf("error parsing template: %v", err) + } + + var output bytes.Buffer + // Execute the template, substituting environment variables + err = t.Execute(&output, envVars) + if err != nil { + return nil, fmt.Errorf("error executing template: %v", err) + } + + err = json.Unmarshal(output.Bytes(), &res) + if err != nil { + return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err) + } + + return res, nil +} + +// getEnvMap Convert the output of os.Environ() to a map +func getEnvMap() map[string]string { + envMap := make(map[string]string) + + for _, env := range os.Environ() { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + return envMap +} + // CopyFileContents copies contents of the given src file to the dst file func CopyFileContents(src, dst string) (err error) { in, err := os.Open(src) diff --git a/util/file_suite_test.go b/util/file_suite_test.go new file mode 100644 index 000000000..3de7db49b --- /dev/null +++ b/util/file_suite_test.go @@ -0,0 +1,126 @@ +package util_test + +import ( + "crypto/md5" + "encoding/hex" + "io" + "os" + + . "github.com/onsi/ginkgo" + . 
"github.com/onsi/gomega" + + "github.com/netbirdio/netbird/util" +) + +var _ = Describe("Client", func() { + + var ( + tmpDir string + ) + + type TestConfig struct { + SomeMap map[string]string + SomeArray []string + SomeField int + } + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + err := os.RemoveAll(tmpDir) + Expect(err).NotTo(HaveOccurred()) + }) + + Describe("Config", func() { + Context("in JSON format", func() { + It("should be written and read successfully", func() { + + m := make(map[string]string) + m["key1"] = "value1" + m["key2"] = "value2" + + arr := []string{"value1", "value2"} + + written := &TestConfig{ + SomeMap: m, + SomeArray: arr, + SomeField: 99, + } + + err := util.WriteJson(tmpDir+"/testconfig.json", written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) + Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) + Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) + Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) + + }) + }) + }) + + Describe("Copying file contents", func() { + Context("from one file to another", func() { + It("should be successful", func() { + + src := tmpDir + "/copytest_src" + dst := tmpDir + "/copytest_dst" + + err := util.WriteJson(src, []string{"1", "2", "3"}) + Expect(err).NotTo(HaveOccurred()) + + err = util.CopyFileContents(src, dst) + Expect(err).NotTo(HaveOccurred()) + + hashSrc := md5.New() + hashDst := md5.New() + + srcFile, err := os.Open(src) + Expect(err).NotTo(HaveOccurred()) + + dstFile, err := os.Open(dst) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashSrc, srcFile) + 
Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashDst, dstFile) + Expect(err).NotTo(HaveOccurred()) + + err = srcFile.Close() + Expect(err).NotTo(HaveOccurred()) + + err = dstFile.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) + }) + }) + }) + + Describe("Handle config file without full path", func() { + Context("config file handling", func() { + It("should be successful", func() { + written := &TestConfig{ + SomeField: 123, + } + cfgFile := "test_cfg.json" + defer os.Remove(cfgFile) + + err := util.WriteJson(cfgFile, written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(cfgFile, &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + }) + }) + }) +}) diff --git a/util/file_test.go b/util/file_test.go index 3de7db49b..1330e738e 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,126 +1,198 @@ -package util_test +package util import ( - "crypto/md5" - "encoding/hex" - "io" "os" - - . "github.com/onsi/ginkgo" - . 
"github.com/onsi/gomega" - - "github.com/netbirdio/netbird/util" + "reflect" + "strings" + "testing" ) -var _ = Describe("Client", func() { - - var ( - tmpDir string - ) - - type TestConfig struct { - SomeMap map[string]string - SomeArray []string - SomeField int +func TestReadJsonWithEnvSub(t *testing.T) { + type Config struct { + CertFile string `json:"CertFile"` + Credentials string `json:"Credentials"` + NestedOption struct { + URL string `json:"URL"` + } `json:"NestedOption"` } - BeforeEach(func() { - var err error - tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - }) + type testCase struct { + name string + envVars map[string]string + jsonTemplate string + expectedResult Config + expectError bool + errorContains string + } - AfterEach(func() { - err := os.RemoveAll(tmpDir) - Expect(err).NotTo(HaveOccurred()) - }) + tests := []testCase{ + { + name: "All environment variables set", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://env.testing.com", + }, + }, + expectError: false, + }, + { + name: "Missing environment variable", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + // "URL" is intentionally missing + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "", + }, + }, + expectError: 
false, + }, + { + name: "Invalid JSON template", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, // Note the missing closing brace in "{{ .CREDENTIALS }" + expectedResult: Config{}, + expectError: true, + errorContains: "unexpected \"}\" in operand", + }, + { + name: "No substitutions", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "/etc/certs/cert.crt", + "Credentials": "admnlknflkdasdf", + "NestedOption" : { + "URL": "https://testing.com" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/cert.crt", + Credentials: "admnlknflkdasdf", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://testing.com", + }, + }, + expectError: false, + }, + { + name: "Should fail when Invalid characters in variables", + envVars: map[string]string{ + "CERT_FILE": `"/etc/certs/"cert".crt"`, + "CREDENTIALS": `env_credentia{ls}`, + "URL": `https://env.testing.com?param={{value}}`, + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: `"/etc/certs/"cert".crt"`, + Credentials: `env_credentia{ls}`, + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: `https://env.testing.com?param={{value}}`, + }, + }, + expectError: true, + }, + } - Describe("Config", func() { - Context("in JSON format", func() { - It("should be written and read successfully", func() { + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + for key, value := range tc.envVars { + t.Setenv(key, value) + } - m := make(map[string]string) - m["key1"] = "value1" - 
m["key2"] = "value2" + tempFile, err := os.CreateTemp("", "config*.json") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } - arr := []string{"value1", "value2"} - - written := &TestConfig{ - SomeMap: m, - SomeArray: arr, - SomeField: 99, + defer func() { + err = os.Remove(tempFile.Name()) + if err != nil { + t.Logf("Failed to remove temp file: %v", err) } + }() - err := util.WriteJson(tmpDir+"/testconfig.json", written) - Expect(err).NotTo(HaveOccurred()) + _, err = tempFile.WriteString(tc.jsonTemplate) + if err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + err = tempFile.Close() + if err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } - read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) - Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) - Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) - Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) + var result Config - }) - }) - }) + _, err = ReadJsonWithEnvSub(tempFile.Name(), &result) - Describe("Copying file contents", func() { - Context("from one file to another", func() { - It("should be successful", func() { - - src := tmpDir + "/copytest_src" - dst := tmpDir + "/copytest_dst" - - err := util.WriteJson(src, []string{"1", "2", "3"}) - Expect(err).NotTo(HaveOccurred()) - - err = util.CopyFileContents(src, dst) - Expect(err).NotTo(HaveOccurred()) - - hashSrc := md5.New() - hashDst := md5.New() - - srcFile, err := os.Open(src) - Expect(err).NotTo(HaveOccurred()) - - dstFile, err := os.Open(dst) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashSrc, srcFile) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashDst, dstFile) - Expect(err).NotTo(HaveOccurred()) - - err = srcFile.Close() - 
Expect(err).NotTo(HaveOccurred()) - - err = dstFile.Close() - Expect(err).NotTo(HaveOccurred()) - - Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) - }) - }) - }) - - Describe("Handle config file without full path", func() { - Context("config file handling", func() { - It("should be successful", func() { - written := &TestConfig{ - SomeField: 123, + if tc.expectError { + if err == nil { + t.Fatalf("Expected error but got none") } - cfgFile := "test_cfg.json" - defer os.Remove(cfgFile) - - err := util.WriteJson(cfgFile, written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(cfgFile, &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - }) + if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err) + } + } else { + if err != nil { + t.Fatalf("ReadJsonWithEnvSub failed: %v", err) + } + if !reflect.DeepEqual(result, tc.expectedResult) { + t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult) + } + } }) - }) -}) + } +} From 8284ae959cd38f0b8c8cb7b7b711699a21c68417 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 10 Oct 2024 12:35:03 +0200 Subject: [PATCH 29/37] [management] Move testdata to sql files (#2693) --- .github/workflows/golang-test-darwin.yml | 2 +- .github/workflows/golang-test-linux.yml | 2 +- client/cmd/testutil_test.go | 4 +- client/internal/engine_test.go | 4 +- client/server/server_test.go | 2 +- client/testdata/store.sql | 36 +++ client/testdata/store.sqlite | Bin 163840 -> 0 bytes management/client/client_test.go | 22 +- management/server/account_test.go | 2 +- management/server/dns_test.go | 2 +- management/server/management_proto_test.go | 11 +- management/server/management_test.go | 4 +- management/server/nameserver_test.go | 2 +- management/server/peer_test.go 
| 8 +- management/server/route_test.go | 2 +- management/server/sql_store.go | 22 -- management/server/sql_store_test.go | 302 ++++++++---------- management/server/store.go | 65 ++-- management/server/testdata/extended-store.sql | 37 +++ .../server/testdata/extended-store.sqlite | Bin 163840 -> 0 bytes management/server/testdata/store.sql | 33 ++ management/server/testdata/store.sqlite | Bin 163840 -> 0 bytes .../server/testdata/store_policy_migrate.sql | 35 ++ .../testdata/store_policy_migrate.sqlite | Bin 163840 -> 0 bytes .../testdata/store_with_expired_peers.sql | 35 ++ .../testdata/store_with_expired_peers.sqlite | Bin 163840 -> 0 bytes management/server/testdata/storev1.sql | 39 +++ management/server/testdata/storev1.sqlite | Bin 163840 -> 0 bytes 28 files changed, 420 insertions(+), 251 deletions(-) create mode 100644 client/testdata/store.sql delete mode 100644 client/testdata/store.sqlite create mode 100644 management/server/testdata/extended-store.sql delete mode 100644 management/server/testdata/extended-store.sqlite create mode 100644 management/server/testdata/store.sql delete mode 100644 management/server/testdata/store.sqlite create mode 100644 management/server/testdata/store_policy_migrate.sql delete mode 100644 management/server/testdata/store_policy_migrate.sqlite create mode 100644 management/server/testdata/store_with_expired_peers.sql delete mode 100644 management/server/testdata/store_with_expired_peers.sqlite create mode 100644 management/server/testdata/storev1.sql delete mode 100644 management/server/testdata/storev1.sqlite diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 2aaef7564..88db8c5e8 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -42,4 +42,4 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 
./... + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index d6adcb27a..e1e1ff236 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -49,7 +49,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... test_client_on_docker: runs-on: ubuntu-20.04 diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 033d1bb6a..d998f9ea9 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -38,7 +38,7 @@ func startTestingServices(t *testing.T) string { signalAddr := signalLis.Addr().String() config.Signal.URI = signalAddr - _, mgmLis := startManagement(t, config, "../testdata/store.sqlite") + _, mgmLis := startManagement(t, config, "../testdata/store.sql") mgmAddr := mgmLis.Addr().String() return mgmAddr } @@ -71,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. 
t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir()) + store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 3d1983c6b..74b10ee44 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -832,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) { return } defer sigServer.Stop() - mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite") + mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql") if err != nil { t.Fatal(err) return @@ -1080,7 +1080,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index e534ad7e2..61bdaf660 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } diff --git a/client/testdata/store.sql b/client/testdata/store.sql new file mode 100644 index 000000000..ed5395486 --- /dev/null +++ b/client/testdata/store.sql @@ -0,0 +1,36 @@ +PRAGMA 
foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` 
text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` 
text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 
21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO installations VALUES(1,''); + +COMMIT; diff --git a/client/testdata/store.sqlite b/client/testdata/store.sqlite deleted file mode 100644 index 118c2bebc9f1fd29751627c36304d301ba156781..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840 zcmeI5Piz}ke#beIMbfe*#>plVXQN$8iDR>oY?6`{$qU1kWjbrbmJ`Xwu3-ej5&1|> zM9wfXL)k*G1!yP1Vu512XpTMf)I$&3Vu3d2JrzCdvA|x7_S#zy1+s_!=FPm}d*slP z>)j$DzO>BYy!U(mKEL1l{ob45pPi3xTRzJ-9Jg)y`Q_A0DNRfLc|M;?rQW3f7wEru zyg@Hch!Z+$((6$#-%MR}|8^=6&V3Weyqf!F`ZuQ^PG3n+P3NY*p4yq*o4PQ0LNh16 zPW?9Z75$^fyDz8pYiG5TW!Kqb@6#4_&@j8c!_SST>vTJwQ3{W*^yXg5=pU_Xy{kOi zsy{Y5%=H}GY#C;)#yrpPoqc9|M%QDmVbzm!&ung2HttpOx3+FnewwcyT}?HAcPn2_ zuB>{Y8Z}leJlUDe=ZVVk!Lx6**w55nFR7B}y1?cCemxV3dJUv2DjU3f;vGC@2iBa5}AD)khf|*P2n{hnd~!L38Rb)tvbGsbO;(loLLN zN)MRR?UQN!yslZX#fC-q8Nngx=}~o)fLdvgQ22Q!!;SNo?Z`(_6}+CTYMeZ+dW3QE zCPHuIF~8d}qy(!Y4699MY3w3FY&wSHDPsG~wOLC^syVJOf+{c7X_KhyFL75&CJP3G_n#T`C}o8vQdIHDfk&1IkE=EB-3jM6MRd63TN*XFg} z+1Llg(_Ido{lv_$Q|f0*!8UYI{j34k_w=0;8GU|U`%>r5sxQ}nHKn?(e#Ue7E_tFz zA!kIS3(cIyt*ni@onT2+A4qlC0WC;nvn5lfg&8fo$YaDYFlk{`O$gCl&ulWIWwkA8 
zJ1(7Iy3b)oePSSH$f008+;rB%ndHBJkF#brE5<~$~fN$AN z>W1ZUpK4H)A+ri%$-s&*7UjPF7!LQ@1*DF|f-Y-zTc#Ur5)Lf($QWLS#?sM#km5(j zY>x8!vGgYt?V}RIu^~yS)-tWO;;a=Zv(<7QjY34rUe~1UC%)aB3#DT5;_Ax9;>F^^;=&!%+o!Zkm)15mHm+V> z_}Rmcx4!z--BfCtI;Q1*pQ1lJKmY_l00ck)1V8`;KmY_l00ck)1ioVg&QIPyd7KLl z;r{;*soWpFV*^BYK>!3m00ck)1V8`;KmY_l00ck)1dc^ua&qQm`2PP~DwjJ}5eR_* z2!H?xfB*=900@8p2!H?xfWSBq=-FCs=KlS+=5}^=ve|2Avf1O@MBgZHu2ojc*O!*p zZeCtmxqNwf>H2c1vUKCd>gwj|`s!xo`t=*EVRl=7p1uvqW%QjL^u-hOX8+Psac!xz zo-Y@#tgK!s)0_2+m)Dn9*UIJa{r=_UKxBQXv^HAg=vs>O>4GKxmoFCSQv~$)E`NUb zG4B75gV|yXAOHd&00JNY0w4eaAOHd&00JN|nm};>AD{mp%?CLk00JNY0w4eaAOHd& z00JNY0wC~hCve=){|E2?ryu>p0|Y<-1V8`;KmY_l00ck)1V8`;Kwyjs1o!{3{~u!o zV`v}%0w4eaAOHd&00JNY0w4eaAV37L|A!fX00@8p2!H?xfB*=900@8p2!O!&6Ttp| z{B?{Gf&d7B00@8p2!H?xfB*=900@9U@cI9V*?&#t>hy#M2!H?xfB*=900@8p2!H?x zfB*=5+X?hupP0RW|J7`E=1lh5oma9mjg`h~ef7Qa(&}0Qiavk<2!H?xfB*=900@8p2!H?xfPg|E*#95r zO$&Vge;h&l{-44ZCm;X#3c|znss$B&$jB1jSh1?$2MDrS*tP6Gkj;C z*`5*n$f9A@L&SgDHj;aAW^=o;aj%lUwRNNN(|q;lu2%DRxAN8G%BmNt(SWLjCp)tl zeSKZ~e8V?)TP(Ria%wK>)4h%Bw=2A@KWb$OdZ^|v(9o-SpFQ>$<1s|1w%KMPYtQud zSUuF9NMnyXmTUT!V;gnTC(+>9x7uv6T5}n_t{c0DLRh4jzAV==JwLDo>Ox%j!|hvl zHntz+Kd3widskT8yt}n?Z+qj`*1deSvCoaiGdiZ>dq%UGzj=4Na_jxAs6ecU{C4GL zWxKMqS=q_+p?Rbx8f>AkJCo5b>soIrw$h-G7nrEiwSBTnem*;FB7<_FQSel%ykDLg zIi1$8>Z9gOSWk%`W}aR>mCwxrJ37F5k! 
zp4SZu2bPz=QMtKs|MtE70{wXE!eZ63eb$UD%ViDHM>WsWn46F(3Hyf~G5DjcC}J1- zq69;d*cKF95(Q&(njSQ#4pYsEpPw2wr$IU4W2p3iIo&>)*3avj6OXtv5rby}w1aNM>cHBgV^83!~na!`uYBR?^@Yb7R!ok`5H zng7Uf_YJdNcPX|pbrjyA=d*S&11$=FG)Z=utMnrh=(tY3Tl1Ca#)lZ?w@tg-Fl&C- z9VpN21Iu;nHnYj6h6>>28dj^{6&ZJ(0&mOQ?V}C!Yk0otcY_Ra#U5kI+?54tj%~9V zd39;V%lUlb8$2kClyTXoU5i*5W`~A-V3Lw!?-95+KGh>z^7dp)^5~S=K}40?m)Z;i zQG!S?u46hS*4bm^A(d{CmkXczFAAT6gk{Oqvq6OL+xi;{-L7ZIH$=Ro}u`- z=kQ&8h7Je8WBaZ`_E}*nti8k!ClAsY{o1_NI~)6;c)H7>v!9q5c1rzBDcFV%s-HFB z`kuaXBBRgGYhUWzS@q@muclPD)z5ei-z85JDUOT?U!j@PxRvEhw-c;l>I11RJD^3c zY_??Tw3eck5_ya`uqCa9stIwo>zPevw5+y8ZO5e(O!qm=s88$#Y-xEat-n7%YRMxt zQR=@=?@ng)@WqE(MxUG0enGB| z9Egw*#|hPXM6zl-V*Ps5iZ*I9^2gzsK^}xP9XDPK)50vu>9|(gB)1VYD!D@uT&z0z zcdKFboxjMTqwxD8i9GyZQW$?EO9`R?DZsbvCUwK|xKA}G%8*$Fv1DMG7mIRVe+-BF z>;h6pVnLTRyDig=*7XM#dt?kRE@SCv&q?v4V>U;5{aE@FiuO^7;nmX`vV2yh8OE3PjT7pq-z2^0`5@*s|@HZVMq3zsI+`sUoIxh725 zD2Z2-wA0Q8yX~3m?^1f|$Ej1L>0h1vuPfCO8=zT$?q z)zYA!y&Ky6L0bPwX_RFSvo%V?Y4eRYGx~dHwcg3txadNYAFLbWEqa*n!(k&Dk__WQ zm^@%CPtU!P(a)aMz6^~elnI~bRHKNB=jH1q{muwIMt)eF4k&S-v5vtjk`zrvq*ZT% z^#txjRLEV6YC6L8hijI2J46)=wg4ai}ufsFd0=^vKPu2a6eB@e~&LrpYHdKg``9adV-Mz zl^o@Sp4UP`{gaJ*!9@n6w#RDwLa|3{F(MO9EZ#r$ifkVpX-mE+D@sP1EL?i^S&JG$ z#Rzpq4ZVKqFTzDlWOM#Yo2&+ypt`CDAMVTIxI z%c^3dww^<+^3~Htx+xJ{FpT_zSWX7lDgGZWo`!vY0oVTA_l%P7o&3q}=X{V$>&E&> zXF}u?Zg0^glce;QQYr1dKibWqgoahuFmOdpjvU{u`#GPS%jlObX}!sCNk4emRV;zx z>#!Hp`8AXuVB+hp{^e}#rwJ%DZ&}O{Qmvd<&5nMy)iUs0u@Gju}|KbV;jvxR6 zAOHd&00JNY0w4eaAOHd&Funxv{eR=DTZ{|@KmY_l00ck)1V8`;KmY_l00aa9-2WFO za0CGm009sH0T2KI5C8!X009sHf$=4P{r~vt79#@z5C8!X009sH0T2KI5C8!X00BV& z@BbGha0CGm009sH0T2KI5C8!X009sHf$=4P`~TysTZ{|@KmY_l00ck)1V8`;KmY_l z00aa9?EeJ`96s009sH0T2KI z5C8!X009sHfk6V;{|^$uIS7CN2!H?xfB*=900@8p2!H?xj2{8K|9|{+iV=YT2!H?x VfB*=900@8p2!H?xfWRPu{{>_kT=f6| diff --git a/management/client/client_test.go b/management/client/client_test.go index 313a67617..100b3fcaa 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -4,7 +4,6 @@ import ( "context" 
"net" "os" - "path/filepath" "sync" "testing" "time" @@ -58,7 +57,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite") + store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -514,22 +513,3 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match") assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") } - -func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) { - t.Helper() - dataDir := t.TempDir() - err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) - if err != nil { - t.Fatal(err) - } - - store, err := mgmt.NewSqliteStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err - } - - return store, func() { - store.Close(ctx) - os.Remove(filepath.Join(dataDir, "store.db")) - }, nil -} diff --git a/management/server/account_test.go b/management/server/account_test.go index c417e4bc8..4dd58e88e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2423,7 +2423,7 @@ func createManager(t TB) (*DefaultAccountManager, error) { func createStore(t TB) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 23941495e..c7f435b68 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -210,7 
+210,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index f8ab46d81..dc8765e19 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -88,7 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error func Test_SyncProtocol(t *testing.T) { dir := t.TempDir() - mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ + mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -413,7 +413,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } @@ -471,6 +471,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie } func Test_SyncStatusRace(t *testing.T) { + t.Skip() if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { t.Skip("Skipping on CI and Postgres store") } @@ -482,9 +483,10 @@ func Test_SyncStatusRace(t *testing.T) { } func testSyncStatusRace(t *testing.T) { t.Helper() + t.Skip() dir := t.TempDir() - mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, 
"testdata/store_with_expired_peers.sqlite", &Config{ + mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -627,6 +629,7 @@ func testSyncStatusRace(t *testing.T) { } func Test_LoginPerformance(t *testing.T) { + t.Skip() if os.Getenv("CI") == "true" || runtime.GOOS == "windows" { t.Skip("Skipping test on CI or Windows") } @@ -655,7 +658,7 @@ func Test_LoginPerformance(t *testing.T) { t.Helper() dir := t.TempDir() - mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ + mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", diff --git a/management/server/management_test.go b/management/server/management_test.go index ba27dc5e8..d53c177d6 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -58,7 +58,7 @@ var _ = Describe("Management service", func() { Expect(err).NotTo(HaveOccurred()) config.Datadir = dataDir - s, listener = startServer(config, dataDir, "testdata/store.sqlite") + s, listener = startServer(config, dataDir, "testdata/store.sql") addr = listener.Addr().String() client, conn = createRawClient(addr) @@ -532,7 +532,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. 
Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir) + store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 7dbd4420c..8a3fe6eb0 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -773,7 +773,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { func createNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 225571f62..f3bf0ddba 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1004,7 +1004,7 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1069,7 +1069,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1135,7 +1135,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported 
by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1188,6 +1188,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed) + assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC()) assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) } diff --git a/management/server/route_test.go b/management/server/route_test.go index fbe022102..09cbe53ff 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1257,7 +1257,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index fe4dcafdb..615203bee 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -911,28 +911,6 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data return store, nil } -// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB. 
-func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(ctx, dsn, metrics) - if err != nil { - return nil, err - } - - err = store.SaveInstallationID(ctx, fileStore.InstallationID) - if err != nil { - return nil, err - } - - for _, account := range fileStore.GetAllAccounts(ctx) { - err := store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } - - return store, nil -} - // NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewPostgresqlStore(ctx, dsn, metrics) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 4eed09c69..06e118fd2 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -11,14 +11,13 @@ import ( "testing" "time" - nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/testutil" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" + route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/status" @@ -31,7 +30,10 @@ func TestSqlite_NewStore(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when 
creating a new FileStore") @@ -39,15 +41,23 @@ func TestSqlite_NewStore(t *testing.T) { } func TestSqlite_SaveAccount_Large(t *testing.T) { - if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" { - t.Skip("skip large test on non-linux OS due to environment restrictions") + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } + t.Run("SQLite", func(t *testing.T) { - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) runLargeTest(t, store) }) + // create store outside to have a better time counter for the test - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) t.Run("PostgreSQL", func(t *testing.T) { runLargeTest(t, store) }) @@ -199,7 +209,10 @@ func TestSqlite_SaveAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() @@ -213,7 +226,7 @@ func TestSqlite_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") @@ -271,7 +284,10 @@ func TestSqlite_DeleteAccount(t *testing.T) { 
t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) testUserID := "testuser" user := NewAdminUser(testUserID) @@ -293,7 +309,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 1 { @@ -324,7 +340,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { for _, policy := range account.Policies { var rules []*PolicyRule - err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") @@ -332,7 +348,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { for _, accountUser := range account.Users { var pats []*PersonalAccessToken - err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -345,11 +361,10 @@ func TestSqlite_GetAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - if err != nil { - t.Fatal(err) 
- } - defer cleanup() + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -369,11 +384,10 @@ func TestSqlite_SavePeer(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - if err != nil { - t.Fatal(err) - } - defer cleanup() + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -421,11 +435,10 @@ func TestSqlite_SavePeerStatus(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -478,11 +491,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + 
t.Cleanup(cleanUp) + assert.NoError(t, err) + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -532,11 +545,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + existingDomain := "test.com" account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) @@ -555,11 +568,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -579,11 +592,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" user, err := 
store.GetUserByTokenID(context.Background(), id) @@ -598,13 +611,18 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { } func TestMigrate(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newSqliteStore(t) + // TODO: figure out why this fails on postgres + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - err := migrate(context.Background(), store.db) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on empty db") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") @@ -640,7 +658,7 @@ func TestMigrate(t *testing.T) { }, } - err = store.db.Save(act).Error + err = store.(*SqlStore).db.Save(act).Error require.NoError(t, err, "Failed to insert Gob data") type route struct { @@ -656,16 +674,16 @@ func TestMigrate(t *testing.T) { Route: route2.Route{ID: "route1"}, } - err = store.db.Save(rt).Error + err = store.(*SqlStore).db.Save(rt).Error require.NoError(t, err, "Failed to insert Gob data") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on gob populated db") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") - err = store.db.Delete(rt).Where("id = ?", "route1").Error + err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error require.NoError(t, err, "Failed to delete Gob data") prefix = netip.MustParsePrefix("12.0.0.0/24") @@ -675,13 +693,13 @@ func TestMigrate(t *testing.T) { Peer: "peer-id", } - err = 
store.db.Save(nRT).Error + err = store.(*SqlStore).db.Save(nRT).Error require.NoError(t, err, "Failed to insert json nil slice data") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on json nil slice populated db") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") } @@ -716,63 +734,15 @@ func newAccount(store Store, id int) error { return store.SaveAccount(context.Background(), account) } -func newPostgresqlStore(t *testing.T) *SqlStore { - t.Helper() - - cleanUp, err := testutil.CreatePGDB() - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - - postgresDsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) - } - - store, err := NewPostgresqlStore(context.Background(), postgresDsn, nil) - if err != nil { - t.Fatalf("could not initialize postgresql store: %s", err) - } - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - -func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore { - t.Helper() - - store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename) - t.Cleanup(cleanUpQ) - if err != nil { - return nil - } - - cleanUpP, err := testutil.CreatePGDB() - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUpP) - - postgresDsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) - } - - pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil) - require.NoError(t, err) - require.NotNil(t, store) - - return pstore -} - func TestPostgresql_NewStore(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", 
runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") @@ -780,11 +750,14 @@ func TestPostgresql_NewStore(t *testing.T) { } func TestPostgresql_SaveAccount(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() @@ -798,7 +771,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") @@ -852,11 +825,14 @@ func TestPostgresql_SaveAccount(t *testing.T) { } func TestPostgresql_DeleteAccount(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store 
:= newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) testUserID := "testuser" user := NewAdminUser(testUserID) @@ -878,7 +854,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 1 { @@ -909,7 +885,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { for _, policy := range account.Policies { var rules []*PolicyRule - err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") @@ -917,7 +893,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { for _, accountUser := range account.Users { var pats []*PersonalAccessToken - err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -926,11 +902,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } func TestPostgresql_SavePeerStatus(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { 
+ t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -965,11 +944,14 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { } func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) existingDomain := "test.com" @@ -982,11 +964,14 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { } func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -997,11 +982,14 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { } 
func TestPostgresql_GetUserByTokenID(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -1011,11 +999,8 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { } func TestSqlite_GetTakenIPs(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) defer cleanup() if err != nil { t.Fatal(err) @@ -1059,11 +1044,8 @@ func TestSqlite_GetTakenIPs(t *testing.T) { } func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { return } @@ -1104,11 +1086,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { } func TestSqlite_GetAccountNetwork(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, 
err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1130,10 +1109,8 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { } func TestSqlite_GetSetupKeyBySecret(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1152,11 +1129,8 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { } func TestSqlite_incrementSetupKeyUsage(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1187,11 +1161,13 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { } func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) } + group := &nbgroup.Group{ ID: "group-id", AccountID: 
"account-id", diff --git a/management/server/store.go b/management/server/store.go index 50bc6afdf..d914bb8f7 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -9,10 +9,12 @@ import ( "os" "path" "path/filepath" + "runtime" "strings" "time" log "github.com/sirupsen/logrus" + "gorm.io/driver/sqlite" "gorm.io/gorm" "github.com/netbirdio/netbird/dns" @@ -240,30 +242,41 @@ func getMigrations(ctx context.Context) []migrationFunc { } } -// NewTestStoreFromSqlite is only used in tests -func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) { - // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE +// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env. +// Optionally it can load a SQL file to the database. If the filename is empty it will return an empty database +func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) { kind := getStoreEngineFromEnv() if kind == "" { kind = SqliteStoreEngine } - var store *SqlStore - var err error - var cleanUp func() - - if filename == "" { - store, err = NewSqliteStore(ctx, dataDir, nil) - cleanUp = func() { - store.Close(ctx) - } - } else { - store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename) + storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) + if runtime.GOOS == "windows" { + // Vo avoid `The process cannot access the file because it is being used by another process` on Windows + storeStr = storeSqliteFileName } + + file := filepath.Join(dataDir, storeStr) + db, err := gorm.Open(sqlite.Open(file), getGormConfig()) if err != nil { return nil, nil, err } + if filename != "" { + err = loadSQL(db, filename) + if err != nil { + return nil, nil, fmt.Errorf("failed to load SQL file: %v", err) + } + } + + store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil) + if err != nil { + return nil, nil, 
fmt.Errorf("failed to create test store: %v", err) + } + cleanUp := func() { + store.Close(ctx) + } + if kind == PostgresStoreEngine { cleanUp, err = testutil.CreatePGDB() if err != nil { @@ -284,21 +297,25 @@ func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string return store, cleanUp, nil } -func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) { - err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) +func loadSQL(db *gorm.DB, filepath string) error { + sqlContent, err := os.ReadFile(filepath) if err != nil { - return nil, nil, err + return err } - store, err := NewSqliteStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err + queries := strings.Split(string(sqlContent), ";") + + for _, query := range queries { + query = strings.TrimSpace(query) + if query != "" { + err := db.Exec(query).Error + if err != nil { + return err + } + } } - return store, func() { - store.Close(ctx) - os.Remove(filepath.Join(dataDir, "store.db")) - }, nil + return nil } // MigrateFileStoreToSqlite migrates the file store to the SQLite store. 
diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql new file mode 100644 index 000000000..b522741e7 --- /dev/null +++ b/management/server/testdata/extended-store.sql @@ -0,0 +1,37 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` 
text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` 
text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + 
+INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cfefqs706sqkneg59g2g"]',0,0); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["abcd"]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); 
+INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/extended-store.sqlite b/management/server/testdata/extended-store.sqlite deleted file mode 100644 index 81aea8118ccf7d3af562ddece1f3007f1e8fc942..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840 zcmeI5Piz}ke#beYMN+aQ#_=W-C(-UmiDMI%Y*LhD$qU1kW!f=p%ZX%R*D!+NhyTukw&U@-<(bEEZE!smbz2#6Km;T zzyElfJ{<86=v@kZ9`y0u#5wD4N4aqJTQ2iP_S^B_9lJk%F*Q1#9sPQAGr2Q*CV52C zM!rt`KJg{}qsQ6TQ}X4LQbISX?2-MTK@C*3wquH8wQia1maP`t_Y<9|*V6LGD^ll@ z$kwZm)E2XB)6g2KR;e)CRvmMf8MfNCnWgI0_}t@b8>Q8|CFREY)zVLu^1;=Vl{@Q7 zIli*;nX-4Xa_;fwWLjQXkv?B_wCx6qFA$%a_1biI^~$Z%z;h{QsI_IqVUL{I;I!VM zp*5MGwWHZPtm?w>)7YbyZfTBg8fsN@NR+=idXvqTD;A^ARdsvc7xq&$CoI>{Y=@(z zCIp2)+PHChb>p6LvvkjGU2b;m&idxvjny0Lca?H&SLni4TbkQj5wa`H<4aB*C<3PwX|EloO^QVXj+zK>FGUT20}!<9E&R1EgR8ks2WeLme~l@ zs@p0xc2BP`0ky0c?IyGI%52%-T4fC;P8doLi>qM^;bY7jUcSz^8)jvf>o%xW=(}Oo zbt9zkz*?fVq1$$w7v{F9TrFK&-MV#GnbBJM%xqaV99H*i&tf%#TvkLHvtlws!p?rn z@BG1b=CN}Je~hVx*ybYE<`?u0E_W;?Uy=tIoY0zD5N{LP9jD;Up74-MOBD;GEw6cY2^cq zQ=G!}MmBTWEj5&&w|v8Dl3A)drtO3cP$bq()wCmGyUa3JBa&1xE#C;rqExG)IXs9( zqUDGN~+v+LLj6K~ljV3e5LAweNYOnu1c6QA= zluMDr+E*)PH89j}rta)V^c&>VnJF?N?(jL>oEV|~5p{SsmwA?(b8DwCN~36OFO`-r zPfMMXfe#9XdpLA<6Vv@pshb%JwyuNfX7#wfCm$S1%hS`+Q&~8xu3Y!iXw+?WGoHhD zg$GI~*6K`c z=uMr@9h8nS-RCf)F0mJ|rQ&ExzCJxDWQI;r=zdOZC)09qO6o{~Rr!%vt(r|uH~hmR z{nq4Xd6KRj`H5sjZUyDs7avJ!d1^}f1-UwMAijkEJ`%MaKRIeU{$%)|>9KcC&mX%- zDtQoAH?3f1O>;Oer)B9)jogNJ(vUm!gNt4z|E|~cuJiYE=-v1GeiC{3-Xu5vge-+e z0aAdY8+B@iZVR8PQj{UH;<2P>Mi__+Uq6JyeRhFRLjpmI)!Plt@)ikuI(w-0&qD)g zZ#_uyqovgcdHq28Q;POpiT>CSCsk=^dNbmzBT`zUVLlv$h^D=kM$1tOWPX5)30gsY zuDCxjCYL|~Q6~>#TG12q59HkWk(9hPHE68)Caf0xPf5PcP9%OXIYF;d_G*Hjcz^&1 zfB*=900@8p2!H?xfB*=900=yTKxajIV{78})H`EaKbgoJ&x~d=8NROGt`#cF`Q?SV zZEZ0>w@_gh=2jMpD|5w#rNv@CU#L}z+w=MPg}Fk0EtFu*P9iZ*?Ub^AO3)Jz5C8!X009sH0T2KI5C8!X z009sHfme*c>EzbfF#8Yg`v03m_M2C%gJ>=YfB*=900@8p2!H?xfB*=900@ADm`X{~f+ZJ3L{!_;$4Yd4&I 
zg}%f`d8j)(3Vo}T!XD{1|Gof4{3X4`=qo7bdlcwa_QG6#d9JXc%;zsIEM1(Zd(ZQW zE5)Vd`FZzl_W2?gS(z&=4;DGNmOR}aKPR3_K2P_^)AOu&y+5OESE@6gJU%0FV()px zE)7NO%6CNU%2ih5%s8@_2~DNOY7guS`6c_ouEFYyEA{#M%qRT%|48;IQQt!M)XfB*=900@8p z2!H?xfB*=900=|~@b&*N_bG_=|1g65{(t(>KRiGH1V8`;KmY_l00ck)1V8`;KmY^| zF#&%4AM5`^tYCBv1V8`;KmY_l00ck)1V8`;KmY`Y0M`F70}ucK5C8!X009sH0T2KI z5C8!XIQ#^#{y+RWMh`&%1V8`;KmY_l00ck)1V8`;KtTNd-{e0fvg`DM2MB-w2!H?x zfB*=900@8p2!H?xygURtZ;ecDZM~7nOdQW#zI`GyQCp}jRhKTz&n*>e+jE5itIVzB zt4nkF`Gt9{QkbXT3oH)vU3hJqS@%B~`0Il7`-Tg1g*;WO6qYXL3m1#a=jIFfr3=gM zj~i+&y>niorpeT_OK#IxRjV>rc;rAjf#`pifJcf;yfB*=9 z00@8p2!H?xfB*=900@A9M}U9-|1jGf;`{%H*?hqIAGQSoAOHd&00JNY0w4eaAOHd& z00J)^f$!z_|D^0X-T(j6^%V^Q0T2KI5C8!X009sH0T2KI5CDOflmOlTKen|sarW{X zmoKkWKP@iX#l>gc`ajI2{7v&V`%v55DOK;9H`zM<6+)5z1YvIef|Ac)6wkBbmH+jE zn^~C47kV=NQhgcx`oEO@Q-Yp&fB*=900@8p2!H?xfB*=900@8p2)tqhPA9jQsCyr+(`oC#^zCO%v2&_J2w%KIYEwkO4r(XoP z|H;P4;5v)XU1vem`JqI0o``Sf!f!3m00ck) z1V8`;KmY_l00cnbl_Bs}a#4!gp-9GO6zd~$a&fG`#6o={``?L4={JdE-(>&g*ssQa zckKT7#nk9{cJ%Af&E#Jvcale>-;8{n_b11|@ru;BB(nADBelgW+cdO>s#Pk?wpGX6WrnTtFE3E_s;hwS z_Z;ypIKH+~TD@CRZmeG|{ZuI*+|;siXI&}BS5`h#_BvF~J>Hy5%PT9==c|sk-C*$r z;#0F;o9?b&xm6P9`c5m0(L-4|LtQT`4twOx27T}j4Xw%itR2nXVO3Xqej0n!(k;!= zO+&3}4vF$tM{ly(a>ZiwxvFmO`@(*T=7i-Mn(c5~pe6)`Kiar)dv)WUaLZC!45 z?aun<-Hp{7>vxrMZC4nNt+q7PvDJE6xprrxbmRKES0GS?vQfHL+9<8Bl{OX8HJj8# zoz3O8C(`nwEOkZ$E9Hf3ZlY$}aL6vj@npY=^vbzLAyTRGZh1P%v6OsC9yD&;dJ5ug z;>o3>X<3$~r}u=#h(z&nENTxTEvm8589k+0W+PCVZmZOr`hrBo1eT{B6%*rm;Tu`gfX2Y!OMo683Ekrf1+jg55=9Z^iEnQpP zx^-8Xp)Jap*|KgptnOKs#cHIFvLe!$6_XhfcJ^C-=MT0bkDcp_Vhkz7HW#razhGca z5GMMwww02dkTD2&) zv1l*secNG8J_26z8nH6b=C(yFYYP;e@hMVZ3m*3Qkc1^1|ZL6m|Gxl`LG@8sH zpXw?=l&k5DZd1azs}y(}+IAN$*RN_jn$zYP)y1s=fa6aMv~KP%cFdYhSIH)xc1@nYyzd(QlAbXQs%A__*irU1Egx2f?9z zS1$9cFy+=>5N~69skD4~TI!q(d{8jl!=baAnC^E<-ONz1bsbbUtH<>{`QS)eo}QMT z%EDQ7<+`6nqi(C4@f^M@JWxV$WI*_G&7Arjna;FZd=^vfNwwG>O?ty-OQue9DViyf z#|Q#j++3&}6L;ITR%dEMZ|ZdJpmdDsK8G1~iM@a=6-QI@_31%Vp3o@@-Os7*WLhpx 
zNgXM$DnGocRkNw-hJSdZ-trWH(vX=3K(v@E@;k=yW28ghqzaM7#e-}Rc_b^d-1 zz59OOPa+TBo8-oykfrb_KnielqfX7xZQ)Z@iZWzYJeKrK^8!)f>xXc-&n^&ZNFZpj zdb^=n-n@QKXAjl>$z>qztvM-vw6yvluOCQ%O3~gc(H|S)q$&+fZ$_MTL`rKk%!h*z z(OlNjXvs@~%nxueK`W@w758Vmf}L8D|&+dft))(l9Jb^28}h}gw=xoDM>5s z3}0NA4mQ@IeFx(k-ty;+lZF{YjB?f!He1D1(bpZVWQCHCIR4~>;E9EM)E0*5k%i?gh z2PPPvE+AjRa|5B_@A`*8JkEACtFq(y5xd{a`i-V{8SUGKI|1QKzpmo?merw7KYKN_ z_QRC?lfod&>}RVM`qQRs@22GoC#B9t`t;UKRx(7;@H zFEF(@`_8j=A-(Dg?)wivJDHYOPf4Arz<63_L$BRA~+fg{XTgY2Y zhbknK3aTvUo=9?9K6Og^qU0{ELxSN7-;P>WI4kOGx>b8q3O-tVbBM8)evQ)7G31M= z2t|GoHUsBkn{A8q?AN-jS80tVN<~&e6mcE1Vg?KImT5UbA%37*no$qxIoJV-zUibA zp_tx{yl&5<4#zJ=-1%-E-cgR+_lxswbiSo+=aJN^379&T##3iRqKwuCb7me+D_mNO&a&IicsZ z5Lf?X<2K)9P%AsEvg<4MU@dxNyn!X!7u}KVq9bjEH)Z{jo+fkW-+0!f#;2mXI-?W4 zb@VUYNlooUZ@B22f4B*6_Uo)q($$#f2cw(1&vl{V@%zU`Z29~>pV{nD(Q!XaCvfrVORM$ zIg^%iIq8XaAu-~eqZOZu_BWatwTZA1#7rMT4e&X|1UxNka&Oq z2!H?xfB*=900@8p2!H?xfB*;_UIP63Ki2<;SHI{P2!H?xfB*=900@8p2!H?xfB*<^ z0$BgU4nP0|KmY_l00ck)1V8`;KmY_l;P4Z`{r`tw$LJvlfB*=900@8p2!H?xfB*=9 z00`jvKWqR5KmY_l00ck)1V8`;KmY_l00a&{0j&QIzmCyE5C8!X009sH0T2KI5C8!X z009ud{r|855C8!X009sH0T2KI5C8!X009s<`~-0Q|M2S=Jp=&|009sH0T2KI5C8!X z009sH0j&RF10VnbAOHd&00JNY0w4eaAOHd&aQF$}{{O?TWAqRNKmY_l00ck)1V8`; zKmY_l00eOTA2t92AOHd&00JNY0w4eaAOHd&00M`f0M`G9U&rVn2!H?xfB*=900@8p z2!H?xfB*>a`~Rivw+VXU0RkWZ0w4eaAOHd&00JNY0w4eaAaM8yB&B3x0)PL1_!NsC zfdB}A00@8p2!H?xfB*=900@9UF9EFodx_u}1V8`;KmY_l00ck)1V8`;KmY^|9|8RR z|KZaqdISO>00JNY0w4eaAOHd&00JNY0=)#V{_iD%V-NrV5C8!X009sH0T2KI5C8!X lID7plVC(-UmiDMI%Y*Li=$@Av+Fe!utoy*D$Iwmy1bdQ7X>PTladeBzaaq9p!8(-MiqJM{lD{THvd z=))OtLPtgV-0$N%iL1`vjrhXp?|qrq)8CK(?%c=YtI3h^^vE|OTf=)Jmxs?N#?Uv3 z-zUDNfAqTYYEr#XnrG(TdW#!L!Zs zQ@z0)*S3tBZj?&Qb#>3)XO^osUFPU!xp(gQ_07WCc0s$hal7zyt$2JjMeWgsR_t9_ z@p3V2teAPWHJMUZR+KN+JY%=UdKc)OnhyK4y>{zCp>JE-W$JBF^Vm~wHflCJwTwCw zS$l@N$I1Z=k;a}jOvms{+tSO1N230_XV%$lvE(rNT-J9Fg|J95JXx-0xSo%ex)2rq zaP!{7wav%c{lep*cbVDsM;lw)n``$rwzXnqpX%fwVems+KkaKXJ(70<*{mLdk(7*G6JblR&&uKOytWREittuj`?YqAGYjqjBTpX~koVlQ$7~ 
zBbRy2hAt(T4PjVyGE05Wc0Jhud1BSpZ8s*i&m4=@Vo4?25k^qtr5ZKE^MhDSs-qs; z)emT7a-4z*knSx@?*lnZr z^{cy{;WhmX^6M^Rv1u*~lx)jlB?=LebCC17kI;F9=_})~&zdH&(v1cU`@kS2M~=py zOx!Dkw&d=~mgLeYv;2rEw=cCBM716w{Gg@|$$B1HJuZ2~yM~H5^MwRI`vu;w` zQRyD0dkQmZ6MF$$%8w+~yVLzbrqo2P{W-ZioKo{sN=u2XN&50#WUHKqKDTpc+OAt8=KaqAJuaoZ7#;p3LaVVj{p4$gG)AgpRT z(aM_EaA8ivG3y4ojj&P49g5&$mdU@H6|?R9MGhT>-xo>b;X9Lp_#;`09|cGOo@rI7 z8>Y*Bs!mac%*u}?9V@~}l>7PwKSZ?bbqv~$QXmrnu1C;{>I=lh!kAnF z1w@lPi0#B1=$^=#YePwOeX8GF6DF+Z#HV3@pPfkjXmWzymGtcd{on-xAOHd&00JNY z0w4eaAOHd&00JOz1cBCy^7_ui!>PB&c78T7dVX|dbad3;SMOGGrRD7M!rZQLJv+Bh zVoP%?3;C6~{KDe(d^Ve_l=8dt+4+UJTy`!yujLk3v-#ER;?B7(fT{SXzl~iQUSm0I8Q)y0LY*{j)^*_nri zyH9D?t}U;vt=+gW^NWu^+4%Z5j}nP->Zg+aHbFmlfdB}A00@8p2!H?xfB*=900@8p z2)twjE)DODo#gsMu>b!qk^b%_8z8z10w4eaAOHd&00JNY0w4eaAOHd&a3q1V!&778 zxq{&N|8ydqKC%chK>!3m00ck)1V8`;KmY_l00cl_kO;IaB|WjT^X}Bv*4F6g&GVz9 zC%K8fHov}HSe(B#mtVefeQx3U_59qee6BEe`}X4E`r^vsdg0cs+pJkuFliL`q}F%`Nie=`QY*Y`MfW(GM8KKFLHb>S^BcU z9RJa>S^81|{aoSiAAb_u|Nmnm{l`I4BgO&(AOHd&00JNY0w4eaAOHd&00JN&2n-KT zjN$wLf(DKt00JNY0w4eaAOHd&00JNY0wD165jg4h|NZCx>5u;51p*)d0w4eaAOHd& z00JNY0w4eaATYoL{QLjd{|~T&F*FbW0T2KI5C8!X009sH0T2KI5Fi5B|HBMG00ck) z1V8`;KmY_l00ck)1VCW$31I&}_&UZ2K>!3m00ck)1V8`;KmY_l00cn5|Nj5b{kx~8_f#>z&8u?6;`T{Fsc(|WPK`S<^a(*Kx9e@kz8 zfdB}A00@8p2!H?xfB*=900@8p2>cKTTpr$0PI6&m$p8KSC%Mdk{r?X^t>^;?fB*=9 z00@8p2!H?xfB*=900_hg`1}8pJhZ^~|0fajpZ{0V8wvWs3j{y_1V8`;KmY_l00ck) z1V8`;K;T3KTGy1Zot=p*H($SbbEW)ge%Z}mKk6a>libSRvLCV!jIF&wdE36vHt72X zdHS-!-29T3&93sFEBxL6o2}90WiB1V8`;KmY_l00ck)1V8`;K;We!;NSnp{{N*?89ECBAOHd& z00JNY0w4eaAOHd&00O5bFp>VB#H8}u#JTU%|9bAP$A5S3n~|;IzZu>e zKBN40=$pjv6JOImdR=)nsouPxBuuN!p1PmaOpobC)3f=xUUlqd!_{-a@wL{}D=GD( z6{U59XPf1xdV@KxZ5cJ)D3zG&>YlyNELZn`d4X=0gYN%%+vwec^Xr?1we5m-Z{v31 z=UVajt`@aN8(OheIr|q>{4+^}kKWSw> z^ib3;)6k2W$DVq#(HO#0%cwJvwP(0{tQ=@hq_L+B(=j~Lw)C>$k*NRfnRPZ>EIEum zm-XF4AuLi1PnN40uIJkVbs;MJ;pV-EYnzX?`-R6r?=rLNk2bcpH`nfMY-`2JJ~tj$ zZy37g>eZrl=h0^2-rbF`K%@w5vv8-dS=d-FY-xOGE~$wIo5}1>q}1!G(i(}Z)Gy@v 
zCTce=kL;44Pj;I~r(9qZJe4YMm#0R~CDj{hzj+hXljDbp=Qqx#R8>{Jddw|`C-S#* zaeLrtag9kWhDvknTBJ17)oEM@W{GiRMa^o~nPZk_i-y1QfFifam|~q+w=?jmZ#k=+*#Xsu&vF|A5Wc`Et;0cs-a~$ ztU~%IYCMfOJu)R>@30{Tf4mii>;hlZ!;mDlnV2n!f{{6mcbZd!#m$MIpY1lMPC4OY zsC0)pJs3->msG`!EH)^*&uGLXJ>IWw5{OsoCuDw^N^y((bv^V^R0XeRG;VP`?Kq3$ zO#}|aWnQzPO9^H}SWBI3P2aO!Pqr$aShaQAjfw3u$6~cuQpt9NW{bR3qh@#&+o{K- zI_j}q{eVVB4k{+~#Os;dG`-iD@lm*kuE*;B3^XbH(InYr zPOKlHK*O=i&5{?JZhVMge%-K|6{F-eosRO%IxrpEsxymxYM=mKu42~OU6FB@De%^e z-8NcZzq;!gUenJYSL`wto4c|=$+j$3BCjsZcsZX>e4Ph{zA_H`tZ5P}-DuFT4-8Ur z)h!4vzg%;-^JDyTij4`T<(lzMYo zXa>=kl@fW3D6sWf3l)3B-KJ|)nO-yN zCbb=v?qRy8FrzlH7qF%LNK(B!-EYYwHIZw7PVNq;)clmvQX;Do;Z-l&b;Gp8>6vb8 z5?P^SU`Ha6tSG3UnECQUC8bVHDZe6DM-D_th~rS)dPH*EcEtMixD{>KX6TQDGo3sL ztJ+Sq7^a0;nA32~x2jZ{QNnSf3F|rWX_$7}qyBDtV)PFQHSs~>Y;OEFW1o)Q z82$IuUyc54YHFmI{I4_rIXrphUxxlZ@gGA*;)C?(lm9*O50n2iF@5fzsqi0yy}EdB zz44Qj`Y5NgR-%dPdr-dNpy7we!{#Zvb!X0WPv~wTF)?_2^f%z$A|2u4*)OJ2YA&aI zd67@{Zu#yME$#ri1>&xtJ*a4|Njv2-b4!ld@VCWsb_gb#o&g{s5xN0s`Un0Y67OZZ zhEv)L{fOIbX5B{9xs4XK5iCIbtzTPlUE6AD&`0lv)_T6&DycbwnTzUJbtB_&!x#0NY&o89Z zwTnt?Dl(piT{BCEy3?eMksh9yWjXboZab2*+l9j2v{WIPR8(a#^ITC=>cxx7mxW+! 
zEeXmMz7@AFIVqvVUI2t|G#HY4Za+D(V_ENa~~%d|)1 zrD8iFinyL#vZIZ8!*;x=kbj~(hETwZ3+{kXd>&souZX&jp3`2F8$+&)&A)NU1-i#pwCC<@jA}&vnuMg+3+| zmzL~>GCJJPvvWV<3)AQOZDS!R(Sq8;Nc>8Ub3&)JkU;-r^K`-sG(!Nma@mdVY zgcFPRFMdU~jgGV>Uz8OkLrrF`y?)f9Mo`fMol!$?p8d;UQB!%XGhB47KZ1lex^>ni z8E7o@gYk>Hr@B$`?7j0mwtVp}S8h;Y@cwFCv0+=MP^*0Pbe3*P_!kUA{~(r={&kA~ zhl{5{-(SGBpZcCr?|UbIw);8v(@9ld>FZ30e2?2(bjhSwI(_d}S^AUR9O}_<+%{{;ygK>!3m z00ck)1V8`;KmY_l00cl_a0&R&|10V56ZC@@2!H?xfB*=900@8p2!H?xfB*=9z~B)W zR)!N3`2PRkDHbCF0T2KI5C8!X009sH0T2KI5CDNr0@(j|62UnLfB*=900@8p2!H?x zfB*=900;~o0et^|@N|k1fdB}A00@8p2!H?xfB*=900@9UCjspLJBi>N1V8`;KmY_l w00ck)1V8`;KmY^=j{u(kA3U96L?8eHAOHd&00JNY0w4eaAOHd&&`IEb0dP~Nf&c&j diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql new file mode 100644 index 000000000..a9360e9d6 --- /dev/null +++ b/management/server/testdata/store_policy_migrate.sql @@ -0,0 +1,35 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES 
`accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES 
`accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON 
`peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:04:23.538411+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO peers VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=','','"100.103.179.238"','Ubuntu-2204-jammy-amd64-base','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'crocodile','crocodile','2023-02-13 12:37:12.635454796+00:00',1,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K',0,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); +INSERT INTO peers 
VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-34653001fc3b','zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=','','"100.103.26.180"','borg','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'dingo','dingo','2023-02-21 09:37:42.565899199+00:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAILHW',1,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_policy_migrate.sqlite b/management/server/testdata/store_policy_migrate.sqlite deleted file mode 100644 index 0c1a491a68d58e019b5256b60338ab8b5f5e40d4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840 zcmeI5O>7%SmdDwmB~r2}#_=Q_CDDXd;@CtanawYY9vEI)q7#j6*^w+b9wTTr$s*Mj z*-dvhWov>gKsyO$2UyGwHpc~KPji^V%wmAqoaeOI!(jKaK(343!(8^3!vdMZzW!u4 zKdcW=7MABv%Otz%Rdvzx=C()cwIJFqONK!$w*{EFNHWR^jC>QC=_~={-2}& z?&EcObId)USI+l-)Y~^hi`qYqo5JyLO_^8X-%kDZ#J#DD6XR3y@vp`=!dv6#!pAsi z?5oi4LtoNAdc61Y1b=yk3&~22J<=aEsez)@F;weVY-nnyt&4*Fda677QiT7oz;!QK z*>dfX*k+opDpFIFs#T`zqM>dxMHf3d(?q#8Jon`CdZ~E3lvrI`DSebE?_EtfadRzE z9$s1bT-iBUIr(^FI>HwU+-F5Ysx;Z~0>e||PMdBQuUs#UJXhizwYHov*dt@!JFRo5 zNG;}OZAtnTtJyHzH1?=1Ymym&7*Y75-mG0WDOU^IfT-&(4UR+(fohaA0EnVnhTM`XjY?KpMZ?2bC-(PbI zcY@@JJsinR0<>Zq~$0Iz?b6?!G%)k<{o=yanY?lq_)K`t8R!wbs 
zYL#`78oMJ`nFY0+P&zH9$<_I?VrrE&nRP;6x?f!MTZlQvoZ;o_ywX&w+oo>4S}lDy z)rPG26z*9|&^BaU@0f*6+f1yKt`;}1-%iX)ZFz3KEGq_UIJT#;IzcWctTd(#$@B>u zyKT4gd)t}A&K~?jOwGqO8L&3Dplfi+6BGO;ew4vkS`!89C03d0H>WlWnv->We9)Zw z_rTfh3`s4(Emgi*8V(p^ajNp*;)Tp{iAXurHko-v`Vj1VpT8@vRDy({CEV~MPGv4b`7gVW+`r|y5ToKE3u)9svZ#AW}3p9fuyRcxkga7O0}Dk zVFs~)R9`)+F78mz$lC>^9vVZFJMz$EyD>wv6lOeBwQW(V)ijDbEZ7TuS2tM89IP^h zCK^u_rUlyJ2(&e|)~OnSk!f|&$!|$Yr!G~Ej@DP6DLb;JDlMjvgSHi5m8;9mUQ_(I zYZMflQl*F1)UT);lF>0U$gk^+1%|m_psFegt5S&YIR}5fOa5loOV-5JjZdBfcwI(vzcL8sKq^aWenLG`lwT;G#-k45;?r@1e9%USj0dhf=A zZmXAZ0N>>wD87(0BGTDrPW=wdjXG^}N>uAh)z}VANd0C@rcM(xns$-L@M7SwiB)+> zi0;no#lQ?u8tgtE8)J51+B+T4%&`88QyDp z?3~l_$M%s(9)vYi&6`=%9L~vUYjR5>x8a=B=MLTABG<^j%XPWu{M{USb^X4ZL>|6B z$&NohOEIGWDZr4G1~o(0EuShJWvq;#H z*+X$~9_mRu>p_YiZK*NJ>wD7oDcUg6 za|7IvpykzPi@OtJatRa=W%3}Z7Cgb=Ku(??o8XscM~yYtghj!97dF?~q0slHWAx0$ zS3>lM2MB-w2!H?xfB*=900@8p2!H?xfWT7-bPL=oo3R_SZ%l69iA7IF$D`4xxvs9% zh3ZmjDZ5aSa;b%Em0egUWHW_@Og5j(q*6k?nyI8y>Fk1#T1cf6LjGbZdoi6^%w?9c zg797{eKD0Xm~Jdq)fR0G?a;1~WK~H;wxq08WtyrKmh378tIFypbE~)JF3t(5)M7rn zm|9HD&ClJC^leHze}1W0EMB@a_vyWl*S`Gu%}{8H+R4Sg4$&VTAOHd&00JNY0w4ea zAOHd&00JNY0!NI%+3@D%e)b>i_5U}a_%}zagJ>=YfB*=900@8p2!H?xfB*=900@A< zQwbao&rZ7c71;Ox$3yY>Q;Q%I1V8`;KmY_l00ck)1V8`;KmY_@Bm!N9i^n!M-=5vr z*oa0ipNvNLa}m9mUS2BY(^nQUOILFX*<3EOa3v#@7FJgB`Q>~ezg)U9AFQ(~c{Zy`y$uFhT_U-@aj44uB5SB)Z>|INW zK5VdH{UuT<`cMJ=y=Oh&``Etz|EEy=PcM=h(H9T^0T2KI5C8!X009sH0T2KI5C8#} zKsX$m#OMEA8h8Z(5C8!X009sH0T2KI5C8!X0D+@NV85ULH}C%+JMp{F^i_Jo0|Y<- z1V8`;KmY_l00ck)1fDknPo~Gt<~CVu;T8Hx<(Zk-+p$k88(Hl3S zCu6Z_wayxOy`IYJ54IK7$Q2s7)PC;h{_&0C&1;)%_3mn?qhyuTqg(4jC9_()`=eVe zefR#|)|K@AJIkG>x}3daebc&-$}Fbyi^5WBuA*vBUs`(NbQ_XE!VK zo1mQ<+hI+$ZT__0Cv(e9wNs-nW~Z-czvcXbt$FlmjedT`2XE^wAY z&X@WW^441$lcbQ%s81pF2l@q{Lb8qhDCEX!uJu9vhg+3K<5qnobM--KRsX29^ii&| zaffZLrmt_^EZk_9O2(aHPa(pE!eTnJG-rcfNT*ZTh5J&gwYwm-YI*Z}`E~ZJnyP9l z{l+W(#IE;Gf%e+$L`NKbfx00@8p2!H?xfB*=900@8p2!H?x90dZeg>&4nJ+Co7oSUSK zfxQ=d{r`(l{NLlhI0_v^8$kdBKmY_l00ck)1V8`;KmY_l-~}X*3N!8?_o9rQ53|XG 
z+$&+P|9=yTe;xnL3&;;$0s#;J0T2KI5C8!X009sH0T2Lzqe$Rfc$3@D$0f!N^!fkk ze+|WJ^n?cpfB*=900@8p2!H?xfB*=900=zq1iG({O>b_#5{<@AMlatu6^-rZTM_oN zl2s(_3CO^E*R009sH0T2KI5C8!X009sH0T2KI5csYMbhv14^VF&BXslXi^#}Te zR9=6ut*}O}(8x6Q^J@UbW|Mww;NIL|`D|lu-b`-f^}4UxTx#xP^ZGv*|6_>$@Bjf2 z009sH0T2KI5C8!X009sH0T4Li1o&`na&TiH(}>0YH#E)tI&|Wj_`jU^yQ$xvxHoli zVtgt-{?+(K`0vA8;bYvd$G!^vKJ+F1qsMzMPwMwB(|BRtBTYVrD~Pwx@f4|OwmR2M;1l7W<&mN&oR6O zCzsbt#oMLC>e@=_qeOY{rj`>o*AnI7m6gwxoeq_gk2j_xe4)U7Ry3qalMOF0JT>mL z>2~qT^^$e2KWk+}^iWQmqpp_|276@8dwp;Y6{*GCtSw33Vl`WPZW?>kmNm(cRYj~x z28o)_hTLNF<*LT$eNC+Fy25UXWccNpl5Uu`Kuz!p-&$Y2QCz>9xK_Grw=Oxqd~=v|d_UE^Q>NuIZ#E>TEJu ziADGv&vnN=D>V!0rirQ@#UQ)1j;9Arq+iZ93M-W=@0F*MoS5J*@uS9#U5{YB#GYI_ z9^rYO`{J%;F;=4WbRuXER$5SFzB4*X)6}M?G+7s^UpsP@S;)!>rPE@XT%9i~rp{QC zSts>rpjCNO}m8M$VHZ|wfYH72nHe|)8PR|yCnwNFGV-_|oFR@a(THL&T zJ26N5nRD}HSut3{u`G?%Ngw5emBzFonLc4-x9xU*Z!2=x*}iCqA^F%Q1GeNA^vr3h z-<;YkXinDg@j-Lymvem#mF_dA>ys1wS)P+Ui?xexGwLx(PmQXZ1cH^C3CW*CB9_Jd ztmXJ9s=}%#8nieoZ7++nPGLI`of)0B=u42>uC=tt*2FDUH~dy*B{o!1)dOPNOjB4h zkW^JQSF>fSRJ$n|byaHxr26Vnb#aG!Mh+?<_0Sla+>wVS+l?8Tr7+{6s%?u>t)@|I zW5HhNySl+z<_MH2{Lvt(FfGsyN1&~#wNBLt3^%KbPJT;LI(4aPbhN(mOxck&RcSGW ze5$Pgt6W`f_L}0yU8BI;lqx;6rhY};kc^I*L9SS5EHHNc0##K}Se3lGZ^Zlat${CE zL1Cnf#vXKJVkJs#>h_LAO7^^m>E5iN?%0yP<+miA4w+&`RR8&X=d=US5D~Lq`*cdI zy~W5w2AV})&aT?&zZvd&X6?(R$f50uRkh|BYA;hZb_4qLavDsvG6FvC0KUr_p@Tti z-@YpueOj2ZYZt7S$(@M^fB7`mo$-8-H{AW9vzHhdbV|KUU$C_uR4=Q~^*wp_ScE@) zn)`ycoK;V*_ijAswt5){@Lm3a;)^3A!k2C4)bGG_rqecOF}1!_jqT8+*Kf9D>NJ<4 znG$&nFR%@p3zdh&-Ht9bnAntCGM(EiJ;ZbmU`9P+&tOZL@d^I@)1#(5zEcRj_Y;+H zgwM=!UCy&AH@u29wI#`ldw6Wnn%pc$(zYWvk*vtBpq%{t78l`XXStt}t0M>EO1Q6M zLF;jogSO+&U-z2PI_Gr!v3(?x2Vo6W^CrVIF>`X-n%t7eZ8#_OxkER&$Tjlsa$W8@ ze>aC-UBB-pk%#Y3vg41>Qp_ko3NU1)LCuhL%cqJIWyq||SkgDm^F%FQzYmA|^a8$y zc!C;hbefXp%#3TWXB*`kwTCiuO*4!PqcNs@jy~R=`;Yq@-q3 zeK-mc&1E%-mb?_m+yFNuXnFP7;_hshTml6|nLLQ91y3+Ikdx=fCivyqQDeIVOcZ zh<`HuKe1m;|5NPriQiG-KiF$^_qqG(TM_=Iz;z4Wz%@Onwcw!cJII~WQ*;~5oEsd_ zgF^1Wu;Qb+03Q_T3m1<+osIB{H@4V{Or@$fsy3;+ypju^!W% 
zxD{HiC1^UX>$m4D-h7}X=?^+Il{IIh-V~db2ekb6x!^&w3RX@%G&k*N>TA%(!d_l! zq=DJ+o?&Y7-WyNbg>u68uivs-ILuF)FJDiv4>QN%UWs_HGw+p1=Gh0Fs{laz*6&)yEm z>YGd|u@uw4k=N^a(BZhHh`YI)XKu@`l&+U(FL$}Pv0Pj!B`>@-!CyNw$^|*~+QuOG zkKRyTjqpF9$>_?imY7xo6;q7>QZQ zUQXzs7Gmq4Y+N@t8N})qt8TlB-CK(e8E0Tw?F;V6_Rx{G{F}0FNk@~(^RGN@QsYt) zZJp7HUOWER_N1nMsy|!|%s=deR|j=AAZcsN@q@ul-2+{yc>K;uE4F<0wxwLVLi_pU zpkkeK9YC%A-P3uxC}D0GI{v|(PMZ4^{|}p|cH5u9wIBGJ(eP_0-`Vw?Yw-zQER1v} zM1IKSE!r{}mQKIB#a;NbT^t&6;-G8j+fnmJj-z!w=l+=pfBrnz4ck-t{%u!x3gqp> zo(qnzfqWB_HSg+ePfIi{GB+NR-HCT1{QF6+oAd<%zhex90)NKYLFW>5slJkKSn!u8 zH|TejkHd2jKAGg6I2RHF-Z@zDnP7i|nL(TI8^Mt2gRcROKRJ9f2=uM*DxL2&o(JlPu>`V_Fyd7*yAkWpNH_`X!(n7-AL+H2@#KGX5wD14_Iu!rf zUCHAW1V8`;KmY_l00ck)1V8`;KmY_l;HVHd8{VAU&zA<6-~Wf-|2ryLMLR(N1V8`; zKmY_l00ck)1V8`;K;Uo$@cVy<qOK>!3m00ck)1V8`;KmY_l00cnba0IaaKO8rd z0s#;J0T2KI5C8!X009sH0T2LzqeQ^G|DTJ08=^lvKmY_l00ck)1V8`;KmY_l00ck) z1YSG>VJ;ks;q(75o?_7>5C8!X009sH0T2KI5C8!X009u_CxG>TKM@>*00@8p2!H?x zfB*=900@8p2!OzgM*yGyfAMsR9)SP|fB*=900@8p2!H?xfB*=9KtBPj|NDvH7z987 z1V8`;KmY_l00ck)1V8`;UOWQ0|Nq6)DS89~AOHd&00JNY0w4eaAOHd&00R94{ttLC B#c2Ql diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql new file mode 100644 index 000000000..100a6470f --- /dev/null +++ b/management/server/testdata/store_with_expired_peers.sql @@ -0,0 +1,35 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` 
numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` 
text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE 
`posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian 
GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_with_expired_peers.sqlite 
b/management/server/testdata/store_with_expired_peers.sqlite deleted file mode 100644 index ed1133211d28b5b2faf4963e833b2ac521ff7fe0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840 zcmeI5Pi!06eaAVXB}%esuCv+9D!WQ(InJ7}C2{yi6df2_N}{xi^~b9}>$Ml@a7aE; zW05oT%utrRNCB?1Mi8J#i{@CMIrY#(QUqwW*PMzTg7#7%*P=c2(%T*iY!3a+U(WE4 zCDoh4t9)sh!+G!b-n{qu{oe2Q-pp{g{r;M4u=t*;H6$aR4!suQxX_=(nl&vQR27*e^;`WNV*8gtsTQ&?Cl4m?);2DP>nH`tMJ%R8*I zt4IyzX6;M*KC9X=+%$I7lr_naRYj~y28o*QhTLGcN)?UK=c-sfc7@#($?(h7CEYO5 zQWLzw_qJBo3tJE3cZ(0~)+KH&ZfX^5#}?b!EdT z;IWHu6_<-!#f`<{cHHWkPGqUGiNxL52ruRbnMi+}qR>D6wQ@AVXL+tY>{*c60^Kwp zwWSzjGuHm)UK8n-vyH+^rIvQemlDtJT#WEM&pm%&nSmu@y3(t1Zz1LwbB33z^KxCS9GJTGYPIxT zS8KB3Q@Cd>LEDgZy=4|QZ8N@9TrTXc?Zl^~raX14Br681Iku;Fnv=DEvDcis z_rMt{&ZES>phv#I^V(p>_jNp*;_@KH;AXurHkoa*VVj1U88;*~nDy(`&gT`s4 zl}<2D>kzg#(wWg}ioOK7=^9pp%u?J}b;ECfR$@&RRXrefz%+%`14$KCbB&;6m1@={ z!wg~psjhldT|A_ok+%y-Ju&(wx8%Obc4PWxDa?4HY6qfJt!fl^Sg;rRv2L)2Iap;1 zO*Ed$ObfKb5ol^^wN)_!Bh%`lli!e()}B-`T3T0mrX0$esx+8F4%$|LRc=qNcbekI zU8SH{m&zTqrhY};kc^g@L4I9lEHKRd0u@zJScO7_&pG(>EgvCT5oVx_#vZq1VkJsV z>h_^TO7-3t-^>Q(MJZ#k=uT<6np z&~0@x&fvTJ1H~6|21GjB%&FgjxlyZWPKl~rsTw<^38~*~$<%3LM$<0x7+wtQH?b=9 z3DGTGsxh%HH)J}tSGtesp23Vd#9qRd(!(SC%GE(3({~7=^LeBkj_~P8uFZK?dpF4Dei(DoDF7L@5=kMmw*7f^t5_$OU zBs>22EX9lhqyR%!YSau_w|uHdQHIRQj3r$&LQmB4^`~&SFD~F~h$pDATB|N;&LZJZ zW>3W4d8jAttOqH6G^N@gukT4eqG<1w=#35iq$+huZUmflKuW6D)hB}x(X>~SXgNxO z%nfjTf|ggGE$&W?$t6%gl*xmrTJQk91357}G{P@V4jOB&35$aJDQvE@L!qx_*oo-x}M!A053iHat2yYObrx zdqQPCIiHy>OWEXfrowX5`Aj-LozBc<)5)Z;S4o#s$y8=qNKPkHaUpj*nY^7&&16%# zY-;{yGDZ6a(~X&m+Mp9dhji9RvZ|ya`%=cLGDTGi^LCYjRb}PlsnvT^x2J?;a%L_w zlblIT-I`jL^aDzpot-Zf3U}^I{qUn7gqM^5`18;>wUdkeE<}HLfdB}A00@8p2!H?x zfB*=900@8p2z&^z=enC{8ae&CM;&<>wZQ3kyqZPiobT zI6VzXWAw}pdhi6@`!7r<=ck2yJe9ninY*2$oAr~~e0pv^m9lUDPo+(f{IoDXSmfkd zlJsSRY3nae-!QQLZd&gj{lLEd|NBtv_vcBC=nDvd00@8p2!H?xfB*=900@8p2!Mb~ zARLa4;rstC4QxRG1V8`;KmY_l00ck)1V8`;K;WxK;I!ZWH}C%skL#h!-@a6&4PGDs 
z0w4eaAOHd&00JQJWfOR|8NQL-Wzn0H^n;m-a@Lsfw;qJlmaz%TTnP0rWbFi^no~^30N;V}6 zYuW01nWg=e*3Hbr*1FfMTs)b-O^+PU`ofXfj(@=;z#o{exQSQC7$_lc&+>R(<<8b@)_WNxgTxzOb{q z+W2r&{$PJ`L6#S!?Bk{E(bM^Gre#X;iIzN-mOgX|g$=hdSJJL(6IvO7^gN?JL!y|QuBkju089^PEfRn{{1jh&~v+D`MpJERaNqsz>b>CMaw zxm-R!m#1G3Fg*&1^!s%I-_SL$|A%A04#oZ}_Aj)-3j{y_1V8`;KmY_l00ck)1V8`; zK;V@j@Mbv6^*c8jwljK9u?_Lz>{zAG85(>2uY_X%8vC1=^2&4&4FLfV009sH0T2KI z5C8!X009sHfv+%uWSDVhxi=|n7C6hTFGI6oHg=XfD(v;E15 zc_B}a&oCd4AvNXpn;f<6jZWM4=!pc?xm@Z+Px?4b+gdu)Uao28_5Y#RFGF-zs`&N6uGf4cwwkAxw*2?8Jh z0w4eaAOHd&00JNY0w4eauR4K0$o>DJ*IxB*qkSL%0w4eaAOHd&00JNY0w4eaAOHd{ zL%_WLkM;k{5JXuJ009sH0T2KI5C8!X009sH0T6ig31I#I>gxzC1OX5L0T2KI5C8!X z009sH0T2Lzmn9H2E5QB#FN+c7K>!3m00ck)1V8`;KmY_l00cnbl_r4e|F5)q&^QnP z0T2KI5C8!X009sH0T2KI5cpCEL}ULGy3G9|bm{ZhKVABZ@n2o~X#Dob@OW(av*GRV zUxoL>7r0*xeHQvn=x6kgUN^rs!oNGgg=D45j`YWM*d$rY%)z(6{#*tl?u~!(NGVVqKoDa8H#e%?$RIj82wvtWpS%m*eS+W zHx% zETo$zspO`pvHVhbd;v4bx&!sE>gb^b@D zgVh|%(%2s9qZGH&nARuLCu|%y-OitEMGiaL7xghDAKOI0mfV7#IgNLlQOX-*bYQzMyo0M66B_9Ee*0YabML9zg1a@HC0sgfY<@k6jl!;RaDK@Y{@Fs ztV_n8sx<;qUG=ECct|}X2NjTdV)RXJ$$gXU#`MinnDIo_4n(P1)hM>HU@!Dz-CzxK z1j-ctXpoed7HEef(A3mwt6~I(o7F`pzac5DJ*i@}w65|@Ig~Y3X)uL+s;vO4+@4(T zG{uj*N`bd7l{;un{ffFF87(t|T(QnrVC?z@DypKe3VC(ki1+7P17Ea)!ax~~J#NXw zN|c(^?L&!_?0FB6BP=pOJ?QG>g2P zUA5DHGu(B|+Lud_Lpv5LYSlB;PNr-e2lVUZ)R<~z1bo~Xe3vysdxPMqeOF@iMPbUW zU9dJ|ha(aG-K$)C!t+7iaQBDKPGY3jDRnY^!Pa(Aovbd`_w2hDBK*~>+;iS?Rvo#{ zr{SR6>SUb3clig3FOCcdU$&W3zXQ{mR@0otRJ&3&c1V+6zuA(h(_D&XO5`!Tz}9ar zRO%CVTe?(ZVqI>?bZoD5AJaX98Fh%gge|3qNBEVigQh&bLkOMEBjs>}Pfv1f&a)~v zyoyz|A<2rnd!g5w+$=}Zwj(!@tjMmQl=$=>7vU!-xu1}$BM0J2xZ9zi^|;AF+i~Zw zC(US`V>H;1;a-*=P9 z!*?gy@yBN=W)vU=7_w5MX2`nbQ$>n0WL9P@>6+$wqL!~eg~NSu0bfHrL5d#6NiZ0ILdsY`Mr;H(2uQoXJ|8H9-DvYJFo zUJ7Jxfa?>qy!vc$ceYC|fdZmT9z@lG2k0HhiP@nMesOZpSaVHS6x>f?T4|4(tL^CM zZ$o_Oz0gHr{4d8Ijolgjx5%H5{&i$>xHR&g7ycuB`NFS<{xt+~JnlbV(rJvVN5>DlyIo2=V5})x75NlRMSQc~KGtKp6SqRkwFFJa zb^Z37#hVW_B>i!Vrn2U2)SF_{@_?5AJ{R0;R>8`thvrE;n)(`aV&NpOG|<3ocrP)v 
zxcSzL&O$oX3-~TQESn{NOX4Osz3GZUOVzt~vCvmR^n0($VLOs0c-V zD{OksMORxI>DjHdELUlbW|a!8gec+~YDM)H=1o;Iyh7%Vs7XrAtLJ0~Wc5uZl~{`D zK9SeydC=jwrHH$EHqSgQyHs2&(z)El!uDceshG&UIl|w)KF9?*_1eZD`1jsYz7gTS zN0ZSjLCZ0l*ct2Y@h1kDOicem=SMb!%e!?4V;TB;_7~?_(rpB_}zdvs#F) zf3k7iJjozd_F3h?RqV-HbjUaZ%W7ZnjBE!TY0H07)-CC1GBNwcizYQL717oi9q7%A ze`Zf=_FnG}7d`V2JK-C>I_r_NHRkxi;7Q#xU8wlvJ6EjO@}p~(a_tK3_pb#N>m2J0 zYW1Hzy+s!#%oB!=e{iRh=6Q<$hm)sv+h4-9pZS_m|7$0IuuOzs3!WRVmjxi7l{2A}`I+vhJ^_6tPg1-26AA8_b0IO{or4u$5B4{h8MFz%5%ifp_!{8&lmB}(2=uM%DxL2*FOCV{yEeja zCI$|_02DJa_z(Ks_Qk6f!uI=?;n?eKcljOM`%VyHFw;bM(yHF|y91Z;iS|%D!oTwl z*Iu+*BP(<|s!o?N$W!Zqoom6i1oB*MdJ}zrE-fU?a|kVWg4i3p6S)81J%fQQ2!H?x zfB*=900@8p2!H?xfB*=bUjlgk-}%)odIka@00JNY0w4eaAOHd&00JNY0xkhu|944X z3j!bj0w4eaAOHd&00JNY0w4ea=a&H1|L0e?=otur00@8p2!H?xfB*=900@8p2)G1r z|G!HDTMz&N5C8!X009sH0T2KI5C8!XIKKpN{r~*x7Ci$25C8!X009sH0T2KI5C8!X z00EZ(*8eUEY(W47KmY_l00ck)1V8`;KmY_l;QSK6{r~4zx9AxNfB*=900@8p2!H?x zfB*=900_7QaQ)vUfh`Dt00@8p2!H?xfB*=900@8p2%KL6SpT12-J)k800JNY0w4ea zAOHd&00JNY0wCZL!2SO&32Z?C1V8`;KmY_l00ck)1V8`;K;Zlm!1e$0t6TI81V8`; zKmY_l00ck)1V8`;KmY_>0$BgMB(Mbm5C8!X009sH0T2KI5C8!X0D<#Mz`XyTi+vHI zKfFKy1V8`;KmY_l00ck)1V8`;KmY{JAAv9z4n^_(|MRC<^aun%00ck)1V8`;KmY_l z00ck)1iA@e{ohRl`yc=UAOHd&00JNY0w4eaAOHd&aQ+D3`~T-pr|1y~fB*=900@8p z2!H?xfB*=900?vw!1}+N2=+k$1V8`;KmY_l00ck)1V8`;K;Zll!2SQ{Pp9Y+2!H?x WfB*=900@8p2!H?xfB*<|6Zl`7K~6CM diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql new file mode 100644 index 000000000..69194d623 --- /dev/null +++ b/management/server/testdata/storev1.sql @@ -0,0 +1,39 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` 
numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` 
text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); 
+CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('auth0|61bf82ddeab084006aa1bccd','','2024-10-02 17:00:54.181873+02:00','','',0,'a443c07a-5765-4a78-97fc-390d9c1d0e49','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO accounts VALUES('google-oauth2|103201118415301331038','','2024-10-02 17:00:54.225803+02:00','','',0,'b6d0b152-364e-40c1-a8a1-fa7bcac2267f','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','One-off key','one-off','2021-12-24 16:09:45.926075752+01:00','2022-01-23 
16:09:45.926075752+01:00','2021-12-24 16:09:45.926075752+01:00',0,1,'2021-12-24 16:12:45.763424077+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0); +INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 
17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO installations VALUES(1,''); + diff --git a/management/server/testdata/storev1.sqlite b/management/server/testdata/storev1.sqlite deleted file mode 100644 index 9a376698e4d226fc08fa68c12fb9bb4cf50375cd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840 zcmeI5U2GfKb;miPBucg=uCv)poJ29DU9U~Zv^o5y17WKV+N-Q3c_lkpV-XC8nW4`*J9@$Z(o!v!l_ zZtRO4rs=98wM402XSyyL>Mm1sv8yvplpBL{FD`FZ7VlRQtLrP3_Y&3PtEnb7))Upi zl~u1*oyMxkr(4r8zFg+Mzi3Fc78_h(aBAG?)BVMzyOp7BC9Y9#s|kbc8`r&Nol`|> zGdF8T(sx+HhT*2M{f?|jhO8=LLo!Izd^hAayI!qpj6OHS+JP(VrbtFmt|jS)iI%$H z6~4E*y0*CaF!4_1q20UW_2rHAt^1pctLyg@)$LtN7rNMyL_-&w)x_RxFITn_*3fh!OM^`&--(a&VtJT}48|!6gVSHBO~v>;&mE3=7G(B7H_b=w zDhAn%bw1s1BE539QCO+eQ@?yQ`E22QjOTgof|4 zJ;pe#CTwq{Go#xP0||1+HLNz7rMRQ&M$iDQ#HK2$dPr=SX$otFlIp7F8bQ@6)oDqF z8N@i zXg<}L7V3v1(9zUJw{C=HrZq$-zbz@C-rc2pZ6Vvg&C{)=^X7&+5Wx5y#dRXgK1 zW2bM{fn16l+JRVC8=j&1nX+*Z(yy1(WU7@Ba)+mIbJh&)kEkbia}!6oIlFcWqcn>q 
z_r_!VTeIBZCC>+W(>)kE{lr+mQ}Qzd!Pa(AepZj`d-mpNjGvw5KIScF<;(e>#=>sP z&p3te3Qm+j$QcspY%`~Eht@{jj=3ai^rUKRj~1jsvn5lfg&8fo$YXdhaL~f4Iv_-M zb*ah3mfV)9ZLjnI(>;Y5`NW>Xma=2x{GHihAv4f~;C~*kMPqz+hCAdut8yc;*ihS& zthlG6{nq4WIg+*=xrt;&b_LbsNB6iGKQqJqkX#)(5Ld!Ij)bkpO%B_RyBI!hdF-_5 z_+$G_BoD%xs^+b%X$|M(bTqjwk=t+@4Y)%$xX2Ci@A9_nJAXHaj;`N#lgPvOCfV^P zU@2x4AO#q*(xh(4y5&h@!nyqCYkalB%~PxgB!WAt|ZVQlAV%M9W@HqU|UJGB>~t z2wGl!wz#`6CYL|~Q6>+fYT*X@Cvx)U$T+_|Gir>MEljzqpVJw@*$)Bg~O zKcqLjKmY_l00ck)1V8`;KmY_l00cnb^FZJ*Ir7@p)Y{b4#Di~7O!+MS3kP%zgO&ch*1o@kS(a_y+gd>VUqQeN2C^Ak?-?nMQ+2wR9<$P8TFe zsMYI@LERKeLaFFk56<%@-UDkl|7sd90E5C8!X009sH0T2KI5CDNQPT*-|G(Nla;PtO=ZEa0VZe5(1$mY|zlKHQAg8Ph% z#Y`o)vYbh!(<_x!E}zS!N{jTobs09Z{jZC(bE7HT-<@vNgPi2z`RVt)1*`q{GP)krmxRoj1NDKL5PAFzec0=!f z=Wsu2A(&ZPTeA^la;1E^kShq;6ClVh-d-xCSC&!*;dUmK%jAVrsUireTscQ0s#Hp) zN_mBCOWl@XBG4H9YyF~3LW!Q9Kl<)*WhpOI%9W+mq9EK(*RNE|7^8ilo_q=D*9uFVpaS<1b}}Vx}l$jv2kM zlv&EBm(r=FY^9XSEv1)Ji@8i8mCxiWD@&zZNmyDr&g=u-&KL8UL0TUsGI;QzS^^?M z#-v;w=GH}PM`e;NO;Gw$dxJP-f@5C8!X009sH0T2KI5C8!X z_#-AT9i8R;D|VyiseknjFHd?`IPCrZuKE4{e~$kozWYZE0bK_H5C8!X z009sH0T2KI5C8!X0D;c~ffIgdfAoZM(YfdtPK})Z3;KW; z2!H?xfB*=900@A<871&+Wb`WAVh^rdroT&j>CzOvBJ}@aD+#)AQo~DYV=1LyC2Lgx72Qfezlu^zx$r^ zcSFsy553af2YNm|{FJ))@OZXt{dTzdlsdfx^C|VTkf0wHxBksr?^JfckB-yNdDC;g z=8ulg%{_}mhbZB4j}i*Ik8162Jr;J_-TJq_ySMb7n&0TwALch6?#(~i(wBrs5ANT6 zoGA+{PfntQ(^PkoXWM-xWd7hv=uv_H5WD%^0{cPc=TO1Z4@a)D)sxJF{nmc@W_hJF zFV~b!d3ANEq`ZA+ZU4^3clW;~6t&XU&b_yra!z`xFP&|r+~V)F0Gaxe_Kv3AJki0-6wZ8=jXS& z+Wt2W*0P;4X&j8SOOmf0=)3@*c zFGS)>d?Eg?@xM8v!4AU$0T2KI5C8!X009sH0T2KI5CDPSpFldwxYJzC95oA^=E_^t zEO45OMI$$(Z1OZ$H?aTz{nd|>AOHd&00JNY0w4eaAOHd&00Ms^1kTLc|3_YV@qHw{S00@8p2!H?xfB*=900@8p2!OzgCV>0@ zFPbpw2LTWO0T2KI5C8!X009sH0T2Lzmrnru|Cg@=^Z*1v00ck)1V8`;KmY_l00ck) z1YR@&Jpccq38Q`x009sH0T2KI5C8!X009sH0T6ik1aSZVxoLM`@bc@{}g)B8_WLt}_c+HKBCdOq1)^tBR>J)?!wJf%Krb7}UQx<(=8?YP8l; z>$|4ryjm@7w$!Gq1k~x-LRj;%u6NDCrsXA8Dz_IO+`XTelREO;^{T8Gtm#;m#0GZlby(P(tn>5z=F}_a`WPzR 
zV@`J`$N8%~Cwmra7u{tvVv;^LtZotrS866Ce-MjV7WbpJqcn0SwnR4+mh1Vmg+`V>nYEaJy}zgHdDx_+6u7BZObjcD?!{13cM|; z=A$+BE9!=1bj=KM#X4i5xf>LytBS(vjJ_asuX z=RHjKW=(a+mh_#VCFyj^6f>d*+Yhv92ciKY=D1Gilvrnnk%tU*i@cm&wKIM*-1%l5 z$fd}k9f)qq5a~gMOIn(W!tC&Vls>b$c(Hk^dGId%@(MpLth8Nfdt%a%s;%--$noMlT zZJFBkN)IsIQ<#xY>^W>HJ2uYWnH{#|2{a-2pT}#_7@wWt4mr=N-0&(k)V3rm?&)a1 zHMv=iq-{rTB3Y4LK{ff&Jub%2%y2&>S4R%Sm2i(EVe4^|!?xqDUyobSI&C`s*gg}< zgRrKmd5d9Mm^nEeO>Rr%Hk?KS?$8Y`a)bQ4ye<3A-_4<;>-XIx^6Mcobhn#gtN@}&#C&LiYT2_;2%S(aG4R8a3mRFxG z?yh#pB~U<=$%Cj`xPktOoV+!;h(7Ruk5|L z`+oSP%Q1dK;10{)#5Fyrwc()QJII~pDZ2G%&h<~|ej#^aSn<)^fcJ~^go~#i&ct{@ z;6A!yP4<5I{**-60rU%mU4cI+Z>|M)$_=L1HMwJMi-Xw?7;k#ofLsa34FsmY?H@ex zLAEPt^&Q8L=>2BaZ#2EzXxBFE1&F!z^A*>(tqu(O=-trr*T?y93&Sk4pRFkLr_FAE zEyfowafg$hanXe)!(2CdTXZ|&a=#JvN!oG2PVO<5XIH)w<1by}KDLd;ma*Sw!baf~ zPXw=*_?>a}==fpxv`2}%jCDk-qJW~QNMO}FzcUwh^ItvcDx_1rU?0E! z!KE0#c!fKh@rN(y4S!0t)C6;1(FXZ_n4?7&U6md7N z=9!meS1NZabS-yzacg;TrIIXud7OXe$}kt?)N31q5Ip;;@}(I64O)y|3|o%bMSrfl z?Oz&VGGXa}y--Gv`+0idi`K&Q*{*LaB;~fi4=@t5lH;7vX)VOoKiRl$UStsKJFLFz zD)x9SI%J%QW%Vz7Mb<}0+6rEjbxS&$Ox}F$s6~xSMYMHB4So6ipWBO??N@ulMc?|v zPWV#4&iW*6jX8cWd{OsQH!7aKanXt`-+$dwu3e%1{)Mn&owiP)*5K9C>vU7XykO|~ z2X{GXUZ?mBE}q(be-77v>U&0m@14A``#JB#$9b_l)R_?Z0k^m4lF6X-S0WLv_y@Z= zG@#+IYv{S67K|K!T=#PxU5W8GZ*qrGdr9AW+0|VFdDmgDh38jazKO|Nclnp6C0Z7l z7ao&`<8Q?HJ4x;^83+PF#~2C)!Hlc@&L!+p10~(C5G+q_Fz70uM(1LDGRZx2ZX|}h zbGYIw;qitu!!{8#f&t5iKnEOu^7}`F(AawJ(goT%IwpMM^>Kb9IduAkpqP;%*yx*w z=VvX1hwoa3WAC@!?RWU>jWEJ+rit*lUA^mfhi>DOha-s?fAuPNxNP-CR_Jn6i*93( zr`AI!uZQ~*%5$~pE%bxAw2?5cA#~jZVt?>X;`x8~3I>iK00JNY0w4eaAOHd&00JNY z0w8d93E=nt&aQ4TG7ta(5C8!X009sH0T2KI5C8!Xa0%f4ze@r~5C8!X009sH0T2KI z5C8!X009s{{QUi79#@z5C8!X009sH0T2KI5C8!X00EZ( zp8t1A;0OXB00JNY0w4eaAOHd&00JNY0%w;1?*E@%-C|@Q00JNY0w4eaAOHd&00JNY z0wCZL!2aJQfg=cj00@8p2!H?xfB*=900@8p2%KF4=JWqt{C5%hhZhKd00@8p2!H?x zfB*=900@8p2!O!ZBM{}Hktuxt|LiFiBLV>s009sH0T2KI5C8!X009sHfnEaG|MwEX zIS7CN2!H?xfB*=900@8p2!H?xoIL{g{{Pw2DMkbWAOHd&00JNY0w4eaAOHd&00O-P zu>bERf^!f60T2KI5C8!X009sH0T2KI5IB1T@cjST( 
Date: Thu, 10 Oct 2024 14:14:56 +0200 Subject: [PATCH 30/37] Add billing user role (#2714) --- management/server/user.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/management/server/user.go b/management/server/user.go index 38a8ac0c4..71608ef20 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -19,10 +19,11 @@ import ( ) const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" @@ -41,6 +42,8 @@ func StrRoleToUserRole(strRole string) UserRole { return UserRoleAdmin case "user": return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin default: return UserRoleUnknown } From 09bdd271f10fa80f42424ffdb14deb1db60e55a9 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:54:34 +0200 Subject: [PATCH 31/37] [client] Improve route acl (#2705) - Update nftables library to v0.2.0 - Mark traffic that was originally destined for local and applies the input rules in the forward chain if said traffic was redirected (e.g. 
by Docker) - Add nft rules to internal map only if flush was successful - Improve error message if handle is 0 (= not found or hasn't been refreshed) - Add debug logging when route rules are added - Replace nftables userdata (rule ID) with a rule hash --- client/firewall/iptables/acl_linux.go | 57 +++++++- client/firewall/iptables/manager_linux.go | 2 +- client/firewall/iptables/router_linux.go | 15 ++- client/firewall/manager/firewall.go | 6 +- client/firewall/nftables/acl_linux.go | 124 +++++++++++++++--- client/firewall/nftables/router_linux.go | 35 +++-- client/firewall/nftables/router_linux_test.go | 8 +- client/internal/acl/id/id.go | 41 +++++- go.mod | 18 +-- go.sum | 32 ++--- util/net/net.go | 3 +- 11 files changed, 267 insertions(+), 74 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index c6a96a876..c271e592d 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -21,13 +22,19 @@ const ( chainNameOutputRules = "NETBIRD-ACL-OUTPUT" ) +type entry struct { + spec []string + position int +} + type aclManager struct { iptablesClient *iptables.IPTables wgIface iFaceMapper routingFwChainName string - entries map[string][][]string - ipsetStore *ipsetStore + entries map[string][][]string + optionalEntries map[string][]entry + ipsetStore *ipsetStore } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { @@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi wgIface: wgIface, routingFwChainName: routingFwChainName, - entries: make(map[string][][]string), - ipsetStore: newIpsetStore(), + entries: make(map[string][][]string), + optionalEntries: make(map[string][]entry), + 
ipsetStore: newIpsetStore(), } err := ipset.Init() @@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi } m.seedInitialEntries() + m.seedInitialOptionalEntries() err = m.cleanChains() if err != nil { @@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error { } } + ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") + if err != nil { + return fmt.Errorf("list chains: %w", err) + } + if ok { + for _, rule := range m.entries["PREROUTING"] { + err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...) + if err != nil { + log.Errorf("failed to delete rule: %v, %s", rule, err) + } + } + } + for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := ipset.Flush(ipsetName); err != nil { log.Errorf("flush ipset %q during reset: %v", ipsetName, err) @@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error { } } + for chainName, entries := range m.optionalEntries { + for _, entry := range entries { + if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil { + log.Errorf("failed to insert optional entry %v: %v", entry.spec, err) + continue + } + m.entries[chainName] = append(m.entries[chainName], entry.spec) + } + } + clear(m.optionalEntries) + return nil } @@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) } +func (m *aclManager) seedInitialOptionalEntries() { + m.optionalEntries["FORWARD"] = []entry{ + { + spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, + position: 2, + }, + } + + m.optionalEntries["PREROUTING"] = []entry{ + { + spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, + position: 1, + }, + } +} + func (m *aclManager) appendToEntries(chainName 
string, spec []string) { m.entries[chainName] = append(m.entries[chainName], spec) } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 6fefd58e6..94bd2fccf 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering( } func (m *Manager) AddRouteFiltering( - sources [] netip.Prefix, + sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 737b20785..e60c352d5 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error { log.Debug("flushing routing related tables") for _, chain := range []string{chainRTFWD, chainRTNAT} { - table := tableFilter - if chain == chainRTNAT { - table = tableNat - } + table := r.getTableForChain(chain) ok, err := r.iptablesClient.ChainExists(table, chain) if err != nil { @@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error { func (r *router) createContainers() error { for _, chain := range []string{chainRTFWD, chainRTNAT} { if err := r.createAndSetupChain(chain); err != nil { - return fmt.Errorf("create chain %s: %v", chain, err) + return fmt.Errorf("create chain %s: %w", chain, err) } } if err := r.insertEstablishedRule(chainRTFWD); err != nil { - return fmt.Errorf("insert established rule: %v", err) + return fmt.Errorf("insert established rule: %w", err) } - return r.addJumpRules() + if err := r.addJumpRules(); err != nil { + return fmt.Errorf("add jump rules: %w", err) + } + + return nil } func (r *router) createAndSetupChain(chain string) error { diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index a6185d370..556bda0d6 100644 --- a/client/firewall/manager/firewall.go +++ 
b/client/firewall/manager/firewall.go @@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error { // GenerateSetName generates a unique name for an ipset based on the given sources. func GenerateSetName(sources []netip.Prefix) string { // sort for consistent naming - sortPrefixes(sources) + SortPrefixes(sources) var sourcesStr strings.Builder for _, src := range sources { @@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { return merged } -// sortPrefixes sorts the given slice of netip.Prefix in place. +// SortPrefixes sorts the given slice of netip.Prefix in place. // It sorts first by IP address, then by prefix length (most specific to least specific). -func sortPrefixes(prefixes []netip.Prefix) { +func SortPrefixes(prefixes []netip.Prefix) { sort.Slice(prefixes, func(i, j int) bool { addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) if addrCmp != 0 { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index eaf7fb6a0..61434f035 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -11,12 +11,14 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -29,6 +31,7 @@ const ( chainNameInputFilter = "netbird-acl-input-filter" chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" + chainNamePrerouting = "netbird-rt-prerouting" allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) @@ -40,15 +43,14 @@ var ( ) type AclManager struct { - rConn *nftables.Conn - sConn *nftables.Conn - wgIface iFaceMapper - routeingFwChainName string + rConn *nftables.Conn + sConn *nftables.Conn + 
wgIface iFaceMapper + routingFwChainName string workTable *nftables.Table chainInputRules *nftables.Chain chainOutputRules *nftables.Chain - chainFwFilter *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -61,7 +63,7 @@ type iFaceMapper interface { IsUserspaceBind() bool } -func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { +func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) // and is permanent. Using same connection for both type of operations @@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa } m := &AclManager{ - rConn: &nftables.Conn{}, - sConn: sConn, - wgIface: wgIface, - workTable: table, - routeingFwChainName: routeingFwChainName, + rConn: &nftables.Conn{}, + sConn: sConn, + wgIface: wgIface, + workTable: table, + routingFwChainName: routingFwChainName, ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) { } // netbird-acl-forward-filter - m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) - m.addJumpRulesToRtForward() // to netbird-rt-fwd - m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) + m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd + m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME) err = m.rConn.Flush() if err != nil { @@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) { return fmt.Errorf(flushError, err) } + if err := m.allowRedirectedTraffic(chainFwFilter); err != nil { + log.Errorf("failed to allow redirected traffic: %s", err) + 
} + return nil } -func (m *AclManager) addJumpRulesToRtForward() { +// Makes redirected traffic originally destined for the host itself (now subject to the forward filter) +// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the +// netbird peer IP. +func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { + preroutingChain := m.rConn.AddChain(&nftables.Chain{ + Name: chainNamePrerouting, + Table: m.workTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + }) + + m.addPreroutingRule(preroutingChain) + + m.addFwmarkToForward(chainFwFilter) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) { + m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: preroutingChain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Fib{ + Register: 1, + ResultADDRTYPE: true, + FlagDADDR: true, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + SourceRegister: true, + }, + }, + }) +} + +func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { + m.rConn.InsertRule(&nftables.Rule{ + Table: m.workTable, + Chain: chainFwFilter, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: m.chainInputRules.Name, + }, + }, + }) +} + +func (m *AclManager) 
addJumpRulesToRtForward(chainFwFilter *nftables.Chain) { expressions := []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() { }, &expr.Verdict{ Kind: expr.VerdictJump, - Chain: m.routeingFwChainName, + Chain: m.routingFwChainName, }, } _ = m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, - Chain: m.chainFwFilter, + Chain: chainFwFilter, Exprs: expressions, }) } @@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain { return chain } -func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain { +func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain { polAccept := nftables.ChainPolicyAccept chain := &nftables.Chain{ Name: name, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index aa61e1858..9b8fdbda5 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,6 +10,7 @@ import ( "net/netip" "strings" + "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -24,7 +25,7 @@ import ( const ( chainNameRoutingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-nat" + chainNameRoutingNat = "netbird-rt-postrouting" chainNameForward = "FORWARD" userDataAcceptForwardRuleIif = "frwacceptiif" @@ -149,7 +150,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) { } func (r *router) createContainers() error { - r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingFw, Table: r.workTable, @@ -157,25 +157,26 @@ func (r *router) createContainers() error { insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + prio := *nftables.ChainPriorityNATSource - 1 + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ 
Name: chainNameRoutingNat, Table: r.workTable, Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, + Priority: &prio, Type: nftables.ChainTypeNAT, }) r.acceptForwardRules() - err := r.refreshRulesMap() - if err != nil { + if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.conn.Flush() - if err != nil { + if err := r.conn.Flush(); err != nil { return fmt.Errorf("nftables: unable to initialize table: %v", err) } + return nil } @@ -188,6 +189,7 @@ func (r *router) AddRouteFiltering( dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) if _, ok := r.rules[string(ruleKey)]; ok { return ruleKey, nil @@ -248,9 +250,18 @@ func (r *router) AddRouteFiltering( UserData: []byte(ruleKey), } - r.rules[string(ruleKey)] = r.conn.AddRule(rule) + rule = r.conn.AddRule(rule) - return ruleKey, r.conn.Flush() + log.Tracef("Adding route rule %s", spew.Sdump(rule)) + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf(flushError, err) + } + + r.rules[string(ruleKey)] = rule + + log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) + + return ruleKey, nil } func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { @@ -288,6 +299,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return nil } + if nftRule.Handle == 0 { + return fmt.Errorf("route rule %s has no handle", ruleKey) + } + setName := r.findSetNameInRule(nftRule) if err := r.deleteNftRule(nftRule, ruleKey); err != nil { @@ -658,7 +673,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) } - log.Debugf("nftables: removed rules 
for %s", pair.Destination) + log.Debugf("nftables: removed nat rules for %s", pair.Destination) return nil } diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index bbf92f3be..25b7587ac 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -314,6 +314,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) { ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") + t.Cleanup(func() { + require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule") + }) + // Check if the rule is in the internal map rule, ok := r.rules[ruleKey.GetRuleID()] assert.True(t, ok, "Rule not found in internal map") @@ -346,10 +350,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) { // Verify actual nftables rule content verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) - - // Clean up - err = r.DeleteRouteRule(ruleKey) - require.NoError(t, err, "Failed to delete rule") }) } } diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index e27fce439..8ce73655d 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -1,8 +1,11 @@ package id import ( + "crypto/sha256" + "encoding/hex" "fmt" "net/netip" + "strconv" "github.com/netbirdio/netbird/client/firewall/manager" ) @@ -21,5 +24,41 @@ func GenerateRouteRuleKey( dPort *manager.Port, action manager.Action, ) RuleID { - return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action)) + manager.SortPrefixes(sources) + + h := sha256.New() + + // Write all fields to the hasher, with delimiters + h.Write([]byte("sources:")) + for _, src := range sources { + h.Write([]byte(src.String())) + h.Write([]byte(",")) + } + + h.Write([]byte("destination:")) + 
h.Write([]byte(destination.String())) + + h.Write([]byte("proto:")) + h.Write([]byte(proto)) + + h.Write([]byte("sPort:")) + if sPort != nil { + h.Write([]byte(sPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("dPort:")) + if dPort != nil { + h.Write([]byte(dPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("action:")) + h.Write([]byte(strconv.Itoa(int(action)))) + hash := hex.EncodeToString(h.Sum(nil)) + + // prepend destination prefix to be able to identify the rule + return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16])) } diff --git a/go.mod b/go.mod index e7137ce5b..cb37ca4bb 100644 --- a/go.mod +++ b/go.mod @@ -19,8 +19,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.24.0 - golang.org/x/sys v0.21.0 + golang.org/x/crypto v0.28.0 + golang.org/x/sys v0.26.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -38,6 +38,7 @@ require ( github.com/cilium/ebpf v0.15.0 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 + github.com/davecgh/go-spew v1.1.1 github.com/eko/gocache/v3 v3.1.1 github.com/fsnotify/fsnotify v1.7.0 github.com/gliderlabs/ssh v0.3.4 @@ -45,7 +46,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/gopacket v1.1.19 - github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/google/nftables v0.2.0 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 @@ -55,7 +56,7 @@ require ( github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.7 github.com/mattn/go-sqlite3 v1.14.19 - github.com/mdlayher/socket v0.4.1 + github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 @@ -89,10 +90,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.26.0 + golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.19.0 - golang.org/x/sync v0.7.0 - golang.org/x/term v0.21.0 + golang.org/x/sync v0.8.0 + golang.org/x/term v0.25.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.7 @@ -133,7 +134,6 @@ require ( github.com/containerd/containerd v1.7.16 // indirect github.com/containerd/log v0.1.0 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect @@ -219,7 +219,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/text v0.19.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index 4563dc933..05df5c66e 100644 --- a/go.sum +++ b/go.sum @@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod 
h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= +github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= +github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= -github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= -github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod 
h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -871,8 +871,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync 
v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= 
+golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/util/net/net.go b/util/net/net.go index 61b47dbe7..035d7552b 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -11,7 +11,8 @@ import ( const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 + NetbirdFwmark = 0x1BD00 + PreroutingFwmark = 0x1BD01 envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) From b2379175fe856e24c71263b7f0dcfac77ab8a722 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:23:46 +0200 Subject: [PATCH 
32/37] [signal] new signal dispatcher version (#2722) --- go.mod | 2 +- go.sum | 4 ++-- signal/server/signal.go | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index cb37ca4bb..e7e3c17a6 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 05df5c66e..e9bc318d6 100644 --- a/go.sum +++ b/go.sum @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811- github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed 
h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= diff --git a/signal/server/signal.go b/signal/server/signal.go index 63cc43bd7..305fd052b 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -6,6 +6,7 @@ import ( "io" "time" + "github.com/netbirdio/signal-dispatcher/dispatcher" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -13,8 +14,6 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "github.com/netbirdio/signal-dispatcher/dispatcher" - "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" "github.com/netbirdio/netbird/signal/proto" From 0e95f16cdd8462242a6912ef061347caba67041b Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 11 Oct 2024 16:24:30 +0200 Subject: [PATCH 33/37] [relay,client] Relay/fix/wg roaming (#2691) If a peer connection switches from Relayed to ICE P2P, the Relayed proxy still consumes the data the other peer sends. Because the proxy is operating, the WireGuard switches back to the Relayed proxy automatically, thanks to the roaming feature. Extend the Proxy implementation with pause/resume functions. Before switching to the p2p connection, pause the WireGuard proxy operation to prevent unnecessary package sources. Consider waiting some milliseconds after the pause to be sure the WireGuard engine already processed all UDP msg in from the pipe. 
--- client/internal/peer/conn.go | 229 +++++++++++++----------- client/internal/wgproxy/ebpf/proxy.go | 35 +--- client/internal/wgproxy/ebpf/wrapper.go | 102 +++++++++-- client/internal/wgproxy/proxy.go | 5 +- client/internal/wgproxy/proxy_test.go | 2 +- client/internal/wgproxy/usp/proxy.go | 93 +++++++--- 6 files changed, 298 insertions(+), 168 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 0d4ad2396..1b740388d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -82,8 +82,6 @@ type Conn struct { config ConnConfig statusRecorder *Status wgProxyFactory *wgproxy.Factory - wgProxyICE wgproxy.Proxy - wgProxyRelay wgproxy.Proxy signaler *Signaler iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager @@ -106,7 +104,8 @@ type Conn struct { beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc - endpointRelay *net.UDPAddr + wgProxyICE wgproxy.Proxy + wgProxyRelay wgproxy.Proxy // for reconnection operations iCEDisconnected chan bool @@ -257,8 +256,7 @@ func (conn *Conn) Close() { conn.wgProxyICE = nil } - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } @@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon conn.log.Debugf("ICE connection is ready") - conn.statusICE.Set(StatusConnected) - - defer conn.updateIceState(iceConnInfo) - if conn.currentConnPriority > priority { + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) return } conn.log.Infof("set ICE to active connection") - endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) - if err != nil { - return + var ( + ep *net.UDPAddr + wgProxy wgproxy.Proxy + err error + ) + if iceConnInfo.RelayedOnLocal { + wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) + 
if err != nil { + conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) + return + } + ep = wgProxy.EndpointAddr() + conn.wgProxyICE = wgProxy + } else { + directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) + if err != nil { + log.Errorf("failed to resolveUDPaddr") + conn.handleConfigurationFailure(err, nil) + return + } + ep = directEp } - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) - - conn.connIDICE = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } conn.workerRelay.DisableWgWatcher() - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { - if wgProxy != nil { - if err := wgProxy.CloseConn(); err != nil { - conn.log.Warnf("Failed to close turn connection: %v", err) - } - } - conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Pause() + } + + if wgProxy != nil { + wgProxy.Work() + } + + if err = conn.configureWGEndpoint(ep); err != nil { + conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() - - if conn.wgProxyICE != nil { - if err := conn.wgProxyICE.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyICE = wgProxy - conn.currentConnPriority = priority - + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) } @@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.log.Tracef("ICE connection state changed to %s", newState) 
+ if conn.wgProxyICE != nil { + if err := conn.wgProxyICE.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + // switch back to relay connection - if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { + if conn.isReadyToUpgrade() { conn.log.Debugf("ICE disconnected, set Relay to active connection") - err := conn.configureWGEndpoint(conn.endpointRelay) - if err != nil { + conn.wgProxyRelay.Work() + + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } conn.workerRelay.EnableWgWatcher(conn.ctx) @@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - select { - case conn.iCEDisconnected <- changed: - default: - } + conn.notifyReconnectLoopICEDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.ctx.Err() != nil { if err := rci.relayedConn.Close(); err != nil { - log.Warnf("failed to close unnecessary relayed connection: %v", err) + conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) } return } - conn.log.Debugf("Relay connection is ready to use") - conn.statusRelay.Set(StatusConnected) + conn.log.Debugf("Relay connection has been established, setup the WireGuard") - wgProxy := conn.wgProxyFactory.GetProxy() - endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn) + wgProxy, err := conn.newProxy(rci.relayedConn) if err != nil { conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return } - conn.log.Infof("created new wgProxy for relay connection: %s", endpoint) - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.endpointRelay = endpointUdpAddr - conn.log.Debugf("conn 
resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) - defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) - - if conn.currentConnPriority > connPriorityRelay { - if conn.statusICE.Get() == StatusConnected { - log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - return - } + if conn.iceP2PIsActive() { + conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) + conn.wgProxyRelay = wgProxy + conn.statusRelay.Set(StatusConnected) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + return } - conn.connIDRelay = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { + wgProxy.Work() + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } - conn.log.Errorf("Failed to update wg peer configuration: %v", err) + conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) return } conn.workerRelay.EnableWgWatcher(conn.ctx) + wgConfigWorkaround() - - if conn.wgProxyRelay != nil { - if err := conn.wgProxyRelay.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyRelay = wgProxy conn.currentConnPriority = connPriorityRelay - + conn.statusRelay.Set(StatusConnected) + conn.wgProxyRelay = wgProxy + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), 
rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) } @@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { return } - log.Debugf("relay connection is disconnected") + conn.log.Debugf("relay connection is disconnected") if conn.currentConnPriority == connPriorityRelay { - log.Debugf("clean up WireGuard config") - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + conn.log.Debugf("clean up WireGuard config") + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } } if conn.wgProxyRelay != nil { - conn.endpointRelay = nil _ = conn.wgProxyRelay.CloseConn() conn.wgProxyRelay = nil } changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - - select { - case conn.relayDisconnected <- changed: - default: - } + conn.notifyReconnectLoopRelayDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { Relayed: conn.isRelayed(), ConnStatusUpdate: time.Now(), } - - err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState) - if err != nil { + if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) } } @@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool { return true } +func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error { + conn.connIDICE = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connIDICE, ip); err != nil { + return err + } + } + return nil +} + func (conn *Conn) freeUpConnID() { if conn.connIDRelay != "" { for _, hook := range conn.afterRemovePeerHooks { @@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() { } } -func (conn *Conn) 
getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { - if !iceConnInfo.RelayedOnLocal { - return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil - } - conn.log.Debugf("setup ice turn connection") +func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { + conn.log.Debugf("setup proxied WireGuard connection") wgProxy := conn.wgProxyFactory.GetProxy() - ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) - if err != nil { + if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) - if errClose := wgProxy.CloseConn(); errClose != nil { - conn.log.Warnf("failed to close turn proxy connection: %v", errClose) - } - return nil, nil, err + return nil, err + } + return wgProxy, nil +} + +func (conn *Conn) isReadyToUpgrade() bool { + return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay +} + +func (conn *Conn) iceP2PIsActive() bool { + return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected +} + +func (conn *Conn) removeWgPeer() error { + return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) +} + +func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { + select { + case conn.relayDisconnected <- changed: + default: + } +} + +func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { + select { + case conn.iCEDisconnected <- changed: + default: + } +} + +func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { + conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if wgProxy != nil { + if ierr := wgProxy.CloseConn(); ierr != nil { + conn.log.Warnf("Failed to close wg proxy: %v", ierr) + } + } + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Work() } - return ep, wgProxy, nil } func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { diff --git 
a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 27ede3ef1..e850f4533 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -5,7 +5,6 @@ package ebpf import ( "context" "fmt" - "io" "net" "os" "sync" @@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ @@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error { return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { - defer p.removeTurnConn(endpointPort) - - var ( - err error - n int - ) - buf := make([]byte, 1500) - for ctx.Err() == nil { - n, err = remoteConn.Read(buf) - if err != nil { - if ctx.Err() != nil { - return - } - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } - return - } - - if err := p.sendPkg(buf[:n], endpointPort); err != nil { - if ctx.Err() != nil || p.ctx.Err() != nil { - return - } - log.Errorf("failed to write out turn pkg to local conn: %v", err) - } - } -} - // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. 
func (p *WGEBPFProxy) proxyToRemote() { @@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return packetConn, nil } -func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { +func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { localhost := net.ParseIP("127.0.0.1") payload := gopacket.Payload(data) diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go index c5639f840..b6a8ac452 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -4,8 +4,13 @@ package ebpf import ( "context" + "errors" "fmt" + "io" "net" + "sync" + + log "github.com/sirupsen/logrus" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -13,20 +18,55 @@ type ProxyWrapper struct { WgeBPFProxy *WGEBPFProxy remoteConn net.Conn - cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread + ctx context.Context + cancel context.CancelFunc + + wgEndpointAddr *net.UDPAddr + + pausedMu sync.Mutex + paused bool + isStarted bool } -func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - ctxConn, cancel := context.WithCancel(ctx) - addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) - +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { + addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { - cancel() - return nil, fmt.Errorf("add turn conn: %w", err) + return fmt.Errorf("add turn conn: %w", err) } - e.remoteConn = remoteConn - e.cancel = cancel - return addr, err + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + p.wgEndpointAddr = addr + return err +} + +func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { + return p.wgEndpointAddr +} + +func (p *ProxyWrapper) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if 
!p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) + } +} + +func (p *ProxyWrapper) Pause() { + if p.remoteConn == nil { + return + } + + log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map @@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error { } return nil } + +func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { + defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + + buf := make([]byte, 1500) + for { + n, err := p.readFromRemote(ctx, buf) + if err != nil { + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) + p.pausedMu.Unlock() + + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to write out turn pkg to local conn: %v", err) + } + } +} + +func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return 0, ctx.Err() + } + if !errors.Is(err, io.EOF) { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + } + return 0, err + } + return n, nil +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index 96fae8dd1..558121cdd 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -7,6 +7,9 @@ import ( // Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) + AddTurnConn(ctx context.Context, turnConn net.Conn) error + EndpointAddr() *net.UDPAddr + Work() + Pause() CloseConn() error } diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go index b09e6be55..b88ff3f83 100644 --- 
a/client/internal/wgproxy/proxy_test.go +++ b/client/internal/wgproxy/proxy_test.go @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - _, err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go index 83a8725d8..f73500717 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/internal/wgproxy/usp/proxy.go @@ -15,13 +15,17 @@ import ( // WGUserSpaceProxy proxies type WGUserSpaceProxy struct { localWGListenPort int - ctx context.Context - cancel context.CancelFunc remoteConn net.Conn localConn net.Conn + ctx context.Context + cancel context.CancelFunc closeMu sync.Mutex closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool } // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation @@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { return p } -// AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - p.ctx, p.cancel = context.WithCancel(ctx) - - p.remoteConn = remoteConn - - var err error +// AddTurnConn +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. 
+func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { dialer := net.Dialer{} - p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) - return nil, err + return err } - go p.proxyToRemote() - go p.proxyToLocal() + p.ctx, p.cancel = context.WithCancel(ctx) + p.localConn = localConn + p.remoteConn = remoteConn - return p.localConn.LocalAddr(), err + return err +} + +func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { + if p.localConn == nil { + return nil + } + endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) + return endpointUdpAddr +} + +// Work starts the proxy or resumes it if it was paused +func (p *WGUserSpaceProxy) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToRemote(p.ctx) + go p.proxyToLocal(p.ctx) + } +} + +// Pause pauses the proxy from receiving data from the remote peer +func (p *WGUserSpaceProxy) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the localConn @@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote() { +func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } 
log.Debugf("failed to read from wg interface conn: %s", err) @@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() { _, err = p.remoteConn.Write(buf[:n]) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } @@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() { } // proxyToLocal proxies from the Remote peer to local WireGuard -func (p *WGUserSpaceProxy) proxyToLocal() { +// if the proxy is paused it will drain the remote conn and drop the packets +func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) @@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for { n, err := p.remoteConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + _, err = p.localConn.Write(buf[:n]) + p.pausedMu.Unlock() + if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to write to wg interface conn: %s", err) From da3a053e2bed950bf9cf382f0690435548221745 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 12 Oct 2024 08:35:51 +0200 Subject: [PATCH 34/37] [management] Refactor getAccountIDWithAuthorizationClaims (#2715) This change restructures the getAccountIDWithAuthorizationClaims method to improve readability, maintainability, and performance. 
- have dedicated methods to handle possible cases - introduced Store.UpdateAccountDomainAttributes and Store.GetAccountUsers methods - Remove GetAccount and SaveAccount dependency - added tests --- management/server/account.go | 353 +++++++++++++++++----------- management/server/account_test.go | 280 +++++++++++----------- management/server/sql_store.go | 37 +++ management/server/sql_store_test.go | 60 +++++ management/server/store.go | 2 + 5 files changed, 450 insertions(+), 282 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 6ee0015f8..c468b5ecc 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" + "errors" "fmt" "hash/crc32" "math/rand" @@ -50,6 +51,8 @@ const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -1285,7 +1288,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { return "", err } return account.Id, nil @@ -1300,28 +1303,39 @@ func isNil(i idp.Manager) bool { } // addAccountIDToIDPAppMeta update user's app metadata in idp manager -func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error { +func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { + accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + if err != 
nil { + return err + } + cachedAccount := &Account{ + Id: accountID, + Users: make(map[string]*User), + } + for _, user := range accountUsers { + cachedAccount.Users[user.Id] = user + } // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(ctx, userID, account) + user, err := am.lookupUserInCache(ctx, userID, cachedAccount) if err != nil { return err } - if user != nil && user.AppMetadata.WTAccountID == account.Id { + if user != nil && user.AppMetadata.WTAccountID == accountID { // it was already set, so we skip the unnecessary update log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", - account.Id, userID) + accountID, userID) return nil } - err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id}) + err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) if err != nil { return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) } // refresh cache to reflect the update - _, err = am.refreshCache(ctx, account.Id) + _, err = am.refreshCache(ctx, accountID) if err != nil { return err } @@ -1545,48 +1559,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration())) } -// updateAccountDomainAttributes updates the account domain attributes and then, saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims, +// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { - - 
if claims.Domain != "" { - account.IsDomainPrimaryAccount = primaryDomain - - lowerDomain := strings.ToLower(claims.Domain) - userObj := account.Users[claims.UserId] - if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { - account.Domain = lowerDomain - } - // prevent updating category for different domain until admin logs in - if account.Domain == lowerDomain { - account.DomainCategory = claims.DomainCategory - } - } else { + if claims.Domain == "" { log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + return nil } - err := am.Store.SaveAccount(ctx, account) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlockAccount() + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) if err != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err } - return nil + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return nil + } + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting user: %v", err) + return err + } + + newDomain := accountDomain + newCategoty := domainCategory + + lowerDomain := strings.ToLower(claims.Domain) + if accountDomain != lowerDomain && user.HasAdminPower() { + newDomain = lowerDomain + } + + if accountDomain == lowerDomain { + newCategoty = claims.DomainCategory + } + + return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. +// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, +// we compare the account's ID with the domain account ID, and if they don't match, we set the account as +// non-primary account for the domain. 
We don't merge accounts at this stage, because of cases when a domain +// was previously unclassified or classified as public so N users that logged int that time, has they own account +// and peers that shouldn't be lost. func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, - existingAcc *Account, - primaryDomain bool, + userAccountID string, + domainAccountID string, claims jwtclaims.AuthorizationClaims, ) error { - err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain) + primaryDomain := domainAccountID == "" || userAccountID == domainAccountID + err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID) if err != nil { return err } @@ -1594,44 +1629,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount( return nil } -// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, +// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. 
-func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } - var ( - account *Account - err error - ) + lowerDomain := strings.ToLower(claims.Domain) - // if domain already has a primary account, add regular user - if domainAcc != nil { - account = domainAcc - account.Users[claims.UserId] = NewRegularUser(claims.UserId) - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } else { - account, err = am.newAccount(ctx, claims.UserId, lowerDomain) - if err != nil { - return nil, err - } - err = am.updateAccountDomainAttributes(ctx, account, claims, true) - if err != nil { - return nil, err - } - } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account) + newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) if err != nil { - return nil, err + return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) + newAccount.Domain = lowerDomain + newAccount.DomainCategory = claims.DomainCategory + newAccount.IsDomainPrimaryAccount = true - return account, nil + err = am.Store.SaveAccount(ctx, newAccount) + if err != nil { + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) + if err != nil { + return "", err + } + + am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + + return newAccount.Id, nil +} + +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) + defer 
unlockAccount() + + usersMap := make(map[string]*User) + usersMap[claims.UserId] = NewRegularUser(claims.UserId) + err := am.Store.SaveUsers(domainAccountID, usersMap) + if err != nil { + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID) + if err != nil { + return "", err + } + + am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil) + + return domainAccountID, nil } // redeemInvite checks whether user has been invited and redeems the invite @@ -1775,7 +1824,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s // GetAccountIDFromToken returns an account ID associated with this token. func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return "", "", fmt.Errorf("user ID is empty") + return "", "", errors.New(emptyUserID) } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1961,16 +2010,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. +// if domain is not private or domain is invalid, it will return the account ID by user ID. 
// if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // // Use cases: // -// New user + New account + New domain -> create account, user role = admin (if private domain, index domain) +// New user + New account + New domain -> create account, user role = owner (if private domain, index domain) // -// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin) +// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin) // -// New user + New account + Existing Public Domain -> create account, user role = admin +// New user + New account + Existing Public Domain -> create account, user role = owner // // Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain) // @@ -1980,98 +2030,123 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. 
User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) + if claims.UserId == "" { - return "", fmt.Errorf("user ID is empty") + return "", errors.New(emptyUserID) } - // if Account ID is part of the claims - // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - if claims.AccountId != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) - } - return claims.AccountId, nil - } return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) - } else if claims.AccountId != "" { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err != nil { - return "", err - } - - if userAccountID != claims.AccountId { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) - } - - domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { - return "", err - } - - if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { - return userAccountID, nil - } } - start := time.Now() - unlock := am.Store.AcquireGlobalLock(ctx) - defer unlock() - log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) + if claims.AccountId != "" { + return am.handlePrivateAccountWithIDFromClaim(ctx, claims) + } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain) + if cancel != nil { + defer cancel() + } 
if err != nil { - // if NotFound we are good to continue, otherwise return error - e, ok := status.FromError(err) - if !ok || e.Type() != status.NotFound { - return "", err - } + return "", err } userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) - defer unlockAccount() - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) - if err != nil { - return "", err - } - // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, - // we compare the account's ID with the domain account ID, and if they don't match, we set the account as - // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain - // was previously unclassified or classified as public so N users that logged int that time, has they own account - // and peers that shouldn't be lost. - primaryDomain := domainAccountID == "" || account.Id == domainAccountID - if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { - return "", err - } - - return account.Id, nil - } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - var domainAccount *Account - if domainAccountID != "" { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) - defer unlockAccount() - domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) - if err != nil { - return "", err - } - } - - account, err := am.handleNewUserAccount(ctx, domainAccount, claims) - if err != nil { - return "", err - } - return account.Id, nil - } else { - // other error + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } + + if userAccountID != "" { + if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil { + return "", err + } + + return userAccountID, nil + 
} + + if domainAccountID != "" { + return am.addNewUserToDomainAccount(ctx, domainAccountID, claims) + } + + return am.addNewPrivateAccount(ctx, domainAccountID, claims) +} +func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + if domainAccountID != "" { + return domainAccountID, nil, nil + } + + log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain) + cancel := am.Store.AcquireGlobalLock(ctx) + + // check again if the domain has a primary account because of simultaneous requests + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + return domainAccountID, cancel, nil +} + +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) + return "", err + } + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + } + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) + return "", err + } + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return claims.AccountId, nil + } + + // We checked if the domain has a primary 
account already + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", err + } + + err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims) + if err != nil { + return "", err + } + + return claims.AccountId, nil +} + +func handleNotFound(err error) error { + if err == nil { + return nil + } + + e, ok := status.FromError(err) + if !ok || e.Type() != status.NotFound { + return err + } + return nil +} + +func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { + return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 4dd58e88e..b20071cba 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims - type test struct { + var ( + publicDomain = "public.com" + privateDomain = "private.com" + unknownDomain = "unknown.com" + ) + + defaultInitAccount := initUserParams{ + Domain: publicDomain, + UserId: "defaultUser", + } + + initUnknown := defaultInitAccount + initUnknown.DomainCategory = UnknownCategory + initUnknown.Domain = unknownDomain + + privateInitAccount := defaultInitAccount + privateInitAccount.Domain = privateDomain + privateInitAccount.DomainCategory = PrivateCategory + + testCases := []struct { name string inputClaims 
jwtclaims.AuthorizationClaims inputInitUserParams initUserParams @@ -479,156 +498,131 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { expectedPrimaryDomainStatus bool expectedCreatedBy string expectedUsers []string - } - - var ( - publicDomain = "public.com" - privateDomain = "private.com" - unknownDomain = "unknown.com" - ) - - defaultInitAccount := initUserParams{ - Domain: publicDomain, - UserId: "defaultUser", - } - - testCase1 := test{ - name: "New User With Public Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: publicDomain, - UserId: "pub-domain-user", - DomainCategory: PublicCategory, + }{ + { + name: "New User With Public Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: publicDomain, + UserId: "pub-domain-user", + DomainCategory: PublicCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomainCategory: "", + expectedDomain: publicDomain, + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pub-domain-user", + expectedUsers: []string{"pub-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomainCategory: "", - expectedDomain: publicDomain, - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pub-domain-user", - expectedUsers: []string{"pub-domain-user"}, - } - - initUnknown := defaultInitAccount - initUnknown.DomainCategory = UnknownCategory - initUnknown.Domain = unknownDomain - - testCase2 := test{ - name: "New User With Unknown Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: unknownDomain, - UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, + { + name: "New User With Unknown Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: unknownDomain, + UserId: "unknown-domain-user", + 
DomainCategory: UnknownCategory, + }, + inputInitUserParams: initUnknown, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: unknownDomain, + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "unknown-domain-user", + expectedUsers: []string{"unknown-domain-user"}, }, - inputInitUserParams: initUnknown, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: unknownDomain, - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "unknown-domain-user", - expectedUsers: []string{"unknown-domain-user"}, - } - - testCase3 := test{ - name: "New User With Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "New User With Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, - } - - privateInitAccount := defaultInitAccount - privateInitAccount.Domain = privateDomain - privateInitAccount.DomainCategory = PrivateCategory - - testCase4 := 
test{ - name: "New Regular User With Existing Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "New Regular User With Existing Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "new-pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputUpdateAttrs: true, + inputInitUserParams: privateInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleUser, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, }, - inputUpdateAttrs: true, - inputInitUserParams: privateInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, - } - - testCase5 := test{ - name: "Existing User With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing User With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + 
expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase6 := test{ - name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing Account Id With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputUpdateClaimAccount: true, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputUpdateClaimAccount: true, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase7 := test{ - name: "User With Private Category And Empty Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: "", - UserId: "pvt-domain-user", - DomainCategory: 
PrivateCategory, + { + name: "User With Private Category And Empty Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: "", + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: "", + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: "", - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, } - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} { + for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } diff --git a/management/server/sql_store.go 
b/management/server/sql_store.go index 615203bee..de3dfa945 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. return nil } +func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { + accountCopy := Account{ + Domain: domain, + DomainCategory: category, + IsDomainPrimaryAccount: isPrimaryDomain, + } + + fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} + result := s.db.WithContext(ctx).Model(&Account{}). + Select(fieldsToUpdate). + Where(idQueryCondition, accountID). + Updates(&accountCopy) + if result.Error != nil { + return result.Error + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "account %s", accountID) + } + + return nil +} + func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } +func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { + var users []*User + result := s.db.Find(&users, accountIDCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") + } + log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting users from store") + } + + return users, nil +} + func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group result := s.db.Find(&groups, accountIDCondition, accountID) diff --git a/management/server/sql_store_test.go 
b/management/server/sql_store_test.go index 06e118fd2..20e812ea7 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1191,3 +1191,63 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { }) assert.NoError(t, err) } + +func TestSqlite_GetAccoundUsers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + users, err := store.GetAccountUsers(context.Background(), accountID) + require.NoError(t, err) + require.Len(t, users, len(account.Users)) +} + +func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + t.Run("Should update attributes with public domain", func(t *testing.T) { + require.NoError(t, err) + domain := "example.com" + category := "public" + IsDomainPrimaryAccount := false + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should update attributes with private domain", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, 
err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should fail when account does not exist", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) + require.Error(t, err) + }) + +} diff --git a/management/server/store.go b/management/server/store.go index d914bb8f7..131fd8aaa 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -58,9 +58,11 @@ type Store interface { GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error + UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) + GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error From 3a88ac78ff80b77eecfdcf9d7d66663b017419aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 12 Oct 2024 10:44:48 +0200 Subject: [PATCH 35/37] [client] Add table filter rules using iptables (#2727) This specifically concerns the established/related rule since this one is not compatible 
with iptables-nft even if it is generated the same way by iptables-translate. --- client/firewall/nftables/manager_linux.go | 47 +++-- .../firewall/nftables/manager_linux_test.go | 1 + client/firewall/nftables/router_linux.go | 186 ++++++++++++------ 3 files changed, 148 insertions(+), 86 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index d2258ae08..01b08bd71 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain * rule := &nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Ct{ - Key: expr.CtKeySTATE, - Register: 1, - }, - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), - Xor: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, + Exprs: getEstablishedExprs(1), } conn.InsertRule(rule) } + +func getEstablishedExprs(register uint32) []expr.Any { + return []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: register, + }, + &expr.Bitwise{ + SourceRegister: register, + DestRegister: register, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: register, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } +} diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 904050a51..bbe18ab07 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -109,6 +109,7 @@ func 
TestNftablesManager(t *testing.T) { Register: 1, Data: []byte{0, 0, 0, 0}, }, + &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictAccept, }, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 9b8fdbda5..404ba6957 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,6 +10,7 @@ import ( "net/netip" "strings" + "github.com/coreos/go-iptables/iptables" "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" @@ -81,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa } } - err = r.cleanUpDefaultForwardRules() + err = r.removeAcceptForwardRules() if err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } @@ -98,40 +99,7 @@ func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() - return r.cleanUpDefaultForwardRules() -} - -func (r *router) cleanUpDefaultForwardRules() error { - if r.filterTable == nil { - return nil - } - - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list chains: %v", err) - } - - for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { - continue - } - - rules, err := r.conn.GetRules(r.filterTable, chain) - if err != nil { - return fmt.Errorf("get rules: %v", err) - } - - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("delete rule: %v", err) - } - } - } - } - - return r.conn.Flush() + return r.removeAcceptForwardRules() } func (r *router) loadFilterTable() (*nftables.Table, error) { @@ -167,7 +135,9 @@ func (r *router) createContainers() error { Type: 
nftables.ChainTypeNAT, }) - r.acceptForwardRules() + if err := r.acceptForwardRules(); err != nil { + log.Errorf("failed to add accept rules for the forward chain: %s", err) + } if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) @@ -577,19 +547,60 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. -func (r *router) acceptForwardRules() { +func (r *router) acceptForwardRules() error { if r.filterTable == nil { log.Debugf("table 'filter' not found for forward rules, skipping accept rules") - return + return nil } + fw := "iptables" + + defer func() { + log.Debugf("Used %s to add accept forward rules", fw) + }() + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + // filter table exists but iptables is not + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + + fw = "nftables" + return r.acceptForwardRulesNftables() + } + + return r.acceptForwardRulesIptables(ipt) +} + +func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + } else { + log.Debugf("added iptables rule: %v", rule) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) getAcceptForwardRules() [][]string { + intf := r.wgIface.Name() + return [][]string{ + {"-i", intf, "-j", "ACCEPT"}, + {"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}, + } +} 
+ +func (r *router) acceptForwardRulesNftables() error { intf := ifname(r.wgIface.Name()) // Rule for incoming interface (iif) with counter iifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ - Name: "FORWARD", + Name: chainNameForward, Table: r.filterTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookForward, @@ -609,6 +620,15 @@ func (r *router) acceptForwardRules() { } r.conn.InsertRule(iifRule) + oifExprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + } + // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ Table: r.filterTable, @@ -619,36 +639,72 @@ func (r *router) acceptForwardRules() { Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityFilter, }, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, - &expr.Ct{ - Key: expr.CtKeySTATE, - Register: 2, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), - Xor: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: []byte{0, 0, 0, 0}, - }, - &expr.Counter{}, - &expr.Verdict{Kind: expr.VerdictAccept}, - }, + Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } r.conn.InsertRule(oifRule) + + return nil +} + +func (r *router) removeAcceptForwardRules() error { + if r.filterTable == nil { + return nil + } + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + return r.removeAcceptForwardRulesNftables() + } + + return r.removeAcceptForwardRulesIptables(ipt) +} + +func (r *router) 
removeAcceptForwardRulesNftables() error { + chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %v", err) + } + + for _, chain := range chains { + if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + continue + } + + rules, err := r.conn.GetRules(r.filterTable, chain) + if err != nil { + return fmt.Errorf("get rules: %v", err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule: %v", err) + } + } + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) } // RemoveNatRule removes a nftables rule pair from nat chains From d93dd4fc7f47c9e1ac597af2570e6faaa36f1219 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 12 Oct 2024 18:21:34 +0200 Subject: [PATCH 36/37] [relay-server] Move the handshake logic to separated struct (#2648) * Move the handshake logic to separated struct - The server will response to the client after it ready to process the peer - Preload the response messages * Fix deprecated lint issue * Fix error handling * [relay-server] Relay measure auth time (#2675) Measure the Relay client's authentication time --- relay/metrics/realy.go | 52 +++++++++++-- relay/server/handshake.go | 153 ++++++++++++++++++++++++++++++++++++++ relay/server/relay.go | 127 ++++++------------------------- 3 files changed, 223 insertions(+), 109 
deletions(-) create mode 100644 relay/server/handshake.go diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go index 13799713a..4dc98a0e0 100644 --- a/relay/metrics/realy.go +++ b/relay/metrics/realy.go @@ -16,8 +16,10 @@ const ( type Metrics struct { metric.Meter - TransferBytesSent metric.Int64Counter - TransferBytesRecv metric.Int64Counter + TransferBytesSent metric.Int64Counter + TransferBytesRecv metric.Int64Counter + AuthenticationTime metric.Float64Histogram + PeerStoreTime metric.Float64Histogram peers metric.Int64UpDownCounter peerActivityChan chan string @@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { return nil, err } + authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + + peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + m := &Metrics{ - Meter: meter, - TransferBytesSent: bytesSent, - TransferBytesRecv: bytesRecv, - peers: peers, + Meter: meter, + TransferBytesSent: bytesSent, + TransferBytesRecv: bytesRecv, + AuthenticationTime: authTime, + PeerStoreTime: peerStoreTime, + peers: peers, ctx: ctx, peerActivityChan: make(chan string, 10), @@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) { m.peerLastActive[id] = time.Time{} } +// RecordAuthenticationTime measures the time taken for peer authentication +func (m *Metrics) RecordAuthenticationTime(duration time.Duration) { + m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + +// RecordPeerStoreTime measures the time to store the peer in map +func (m *Metrics) RecordPeerStoreTime(duration time.Duration) { + m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + // PeerDisconnected decrements the number of 
connected peers and decrements number of idle or active connections func (m *Metrics) PeerDisconnected(id string) { m.peers.Add(m.ctx, -1) @@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() { } } } + +func getStandardBucketBoundaries() []float64 { + return []float64{ + 0.1, + 0.5, + 1, + 5, + 10, + 50, + 100, + 500, + 1000, + 5000, + 10000, + } +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go new file mode 100644 index 000000000..0257300f8 --- /dev/null +++ b/relay/server/handshake.go @@ -0,0 +1,153 @@ +package server + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/auth" + "github.com/netbirdio/netbird/relay/messages" + //nolint:staticcheck + "github.com/netbirdio/netbird/relay/messages/address" + //nolint:staticcheck + authmsg "github.com/netbirdio/netbird/relay/messages/auth" +) + +// preparedMsg contains the marshalled success response messages +type preparedMsg struct { + responseHelloMsg []byte + responseAuthMsg []byte +} + +func newPreparedMsg(instanceURL string) (*preparedMsg, error) { + rhm, err := marshalResponseHelloMsg(instanceURL) + if err != nil { + return nil, err + } + + ram, err := messages.MarshalAuthResponse(instanceURL) + if err != nil { + return nil, fmt.Errorf("failed to marshal auth response msg: %w", err) + } + + return &preparedMsg{ + responseHelloMsg: rhm, + responseAuthMsg: ram, + }, nil +} + +func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { + addr := &address.Address{URL: instanceURL} + addrData, err := addr.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal response address: %w", err) + } + + //nolint:staticcheck + responseMsg, err := messages.MarshalHelloResponse(addrData) + if err != nil { + return nil, fmt.Errorf("failed to marshal hello response: %w", err) + } + return responseMsg, nil +} + +type handshake struct { + conn net.Conn + validator auth.Validator + preparedMsg *preparedMsg + + handshakeMethodAuth bool 
+ peerID string +} + +func (h *handshake) handshakeReceive() ([]byte, error) { + buf := make([]byte, messages.MaxHandshakeSize) + n, err := h.conn.Read(buf) + if err != nil { + return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) + } + + _, err = messages.ValidateVersion(buf[:n]) + if err != nil { + return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err) + } + + msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) + if err != nil { + return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) + } + + var ( + bytePeerID []byte + peerID string + ) + switch msgType { + //nolint:staticcheck + case messages.MsgTypeHello: + bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n]) + case messages.MsgTypeAuth: + h.handshakeMethodAuth = true + bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n]) + default: + return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) + } + if err != nil { + return nil, err + } + h.peerID = peerID + return bytePeerID, nil +} + +func (h *handshake) handshakeResponse() error { + var responseMsg []byte + if h.handshakeMethodAuth { + responseMsg = h.preparedMsg.responseAuthMsg + } else { + responseMsg = h.preparedMsg.responseHelloMsg + } + + if _, err := h.conn.Write(responseMsg); err != nil { + return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) + } + + return nil +} + +func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { + //nolint:staticcheck + rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) + + authMsg, err := authmsg.UnmarshalMsg(authData) + if err != nil { + 
return nil, "", fmt.Errorf("unmarshal auth message: %w", err) + } + + //nolint:staticcheck + if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} + +func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { + rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + + if err := h.validator.Validate(authPayload); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} diff --git a/relay/server/relay.go b/relay/server/relay.go index 76c01a697..6cd8506ae 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -7,16 +7,13 @@ import ( "net/url" "strings" "sync" + "time" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/relay/auth" - "github.com/netbirdio/netbird/relay/messages" //nolint:staticcheck - "github.com/netbirdio/netbird/relay/messages/address" - //nolint:staticcheck - authmsg "github.com/netbirdio/netbird/relay/messages/auth" "github.com/netbirdio/netbird/relay/metrics" ) @@ -28,6 +25,7 @@ type Relay struct { store *Store instanceURL string + preparedMsg *preparedMsg closed bool closeMu sync.RWMutex @@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida return nil, fmt.Errorf("get instance URL: %v", err) } + r.preparedMsg, err = newPreparedMsg(r.instanceURL) + if err != nil { + metricsCancel() + return nil, fmt.Errorf("prepare message: %v", err) + } + return r, nil } @@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { // Accept start to handle a new peer connection func (r *Relay) Accept(conn net.Conn) { + acceptTime := 
time.Now() r.closeMu.RLock() defer r.closeMu.RUnlock() if r.closed { return } - peerID, err := r.handshake(conn) + h := handshake{ + conn: conn, + validator: r.validator, + preparedMsg: r.preparedMsg, + } + peerID, err := h.handshakeReceive() if err != nil { log.Errorf("failed to handshake: %s", err) - cErr := conn.Close() - if cErr != nil { + if cErr := conn.Close(); cErr != nil { log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) } return @@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) { peer := NewPeer(r.metrics, peerID, conn, r.store) peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) + storeTime := time.Now() r.store.AddPeer(peer) + r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.PeerConnected(peer.String()) go func() { peer.Work() @@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) { peer.log.Debugf("relay connection closed") r.metrics.PeerDisconnected(peer.String()) }() + + if err := h.handshakeResponse(); err != nil { + log.Errorf("failed to send handshake response, close peer: %s", err) + peer.Close() + } + r.metrics.RecordAuthenticationTime(time.Since(acceptTime)) } // Shutdown closes the relay server @@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) { func (r *Relay) InstanceURL() string { return r.instanceURL } - -func (r *Relay) handshake(conn net.Conn) ([]byte, error) { - buf := make([]byte, messages.MaxHandshakeSize) - n, err := conn.Read(buf) - if err != nil { - return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err) - } - - _, err = messages.ValidateVersion(buf[:n]) - if err != nil { - return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err) - } - - msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) - if err != nil { - return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) - } - - var ( - responseMsg []byte - peerID []byte - ) - switch msgType { - //nolint:staticcheck - 
case messages.MsgTypeHello: - peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - case messages.MsgTypeAuth: - peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - default: - return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr()) - } - if err != nil { - return nil, err - } - - _, err = conn.Write(responseMsg) - if err != nil { - return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) - } - - return peerID, nil -} - -func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) { - //nolint:staticcheck - rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := messages.HashIDToString(rawPeerID) - log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr) - - authMsg, err := authmsg.UnmarshalMsg(authData) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal auth message: %w", err) - } - - //nolint:staticcheck - if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err) - } - - addr := &address.Address{URL: r.instanceURL} - addrData, err := addr.Marshal() - if err != nil { - return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err) - } - - //nolint:staticcheck - responseMsg, err := messages.MarshalHelloResponse(addrData) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err) - } - return rawPeerID, responseMsg, nil -} - -func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) { - rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := 
messages.HashIDToString(rawPeerID) - - if err := r.validator.Validate(authPayload); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err) - } - - responseMsg, err := messages.MarshalAuthResponse(r.instanceURL) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err) - } - - return rawPeerID, responseMsg, nil -} From 49e65109d25a98bfa73245e1252663d37184533a Mon Sep 17 00:00:00 2001 From: ctrl-zzz <78654296+ctrl-zzz@users.noreply.github.com> Date: Sun, 13 Oct 2024 14:52:43 +0200 Subject: [PATCH 37/37] Add session expire functionality based on inactivity (#2326) Implemented inactivity expiration by checking the status of a peer: after a configurable period of time following netbird down, the peer shows login required. --- management/server/account.go | 133 +++++++++ management/server/account_test.go | 315 +++++++++++++++++++++ management/server/activity/codes.go | 14 + management/server/file_store.go | 3 + management/server/http/accounts_handler.go | 3 + management/server/http/api/openapi.yml | 19 ++ management/server/http/api/types.gen.go | 21 +- management/server/http/peers_handler.go | 54 ++-- management/server/peer.go | 97 +++++-- management/server/peer/peer.go | 20 ++ management/server/peer_test.go | 62 ++++ 11 files changed, 682 insertions(+), 59 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index c468b5ecc..7c84ad1ca 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -51,6 +51,7 @@ const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour + DefaultPeerInactivityExpiration = 10 * time.Minute emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) @@ -181,6 +182,8 @@ type DefaultAccountManager struct { dnsDomain string peerLoginExpiry 
Scheduler + peerInactivityExpiry Scheduler + // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool @@ -198,6 +201,13 @@ type Settings struct { // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration + PeerInactivityExpirationEnabled bool + + // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. + // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. + PeerInactivityExpiration time.Duration + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements RegularUsersViewBlocked bool @@ -228,6 +238,9 @@ func (s *Settings) Copy() *Settings { GroupsPropagationEnabled: s.GroupsPropagationEnabled, JWTAllowGroups: s.JWTAllowGroups, RegularUsersViewBlocked: s.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: s.PeerInactivityExpiration, } if s.Extra != nil { settings.Extra = s.Extra.Copy() @@ -609,6 +622,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { return peers } +// GetInactivePeers returns peers that have been expired by inactivity +func (a *Account) GetInactivePeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, inactivePeer := range a.GetPeersWithInactivity() { + inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + return peers +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. 
+func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithInactivity() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + // GetPeers returns a list of all Account peers func (a *Account) GetPeers() []*nbpeer.Peer { var peers []*nbpeer.Peer @@ -975,6 +1042,7 @@ func BuildManager( dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), + peerInactivityExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, metrics: metrics, @@ -1103,6 +1171,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.checkAndSchedulePeerLoginExpiration(ctx, account) } + err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) + if err != nil { + return nil, err + } + updatedAccount := account.UpdateSettings(newSettings) err = am.Store.SaveAccount(ctx, account) @@ -1113,6 +1186,26 @@ func (am *DefaultAccountManager) 
UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { + if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { + event := activity.AccountPeerInactivityExpirationEnabled + if !newSettings.PeerInactivityExpirationEnabled { + event = activity.AccountPeerInactivityExpirationDisabled + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + } else { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) + } + + if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + + return nil +} + func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -1148,6 +1241,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context } } +// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found +func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { + return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + log.Errorf("failed getting account %s expiring peers", account.Id) + return account.GetNextInactivePeerExpiration() + } + + expiredPeers := 
account.GetInactivePeers() + var peerIDs []string + for _, peer := range expiredPeers { + peerIDs = append(peerIDs, peer.ID) + } + + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + return account.GetNextInactivePeerExpiration() + } + + return account.GetNextInactivePeerExpiration() + } +} + +// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { + am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) + if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { + go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) + } +} + // newAccount creates a new Account with a generated ID and generated default setup keys. 
// If ID is already in use (due to collision) we try one more time before returning error func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { @@ -2412,6 +2542,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac PeerLoginExpiration: DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, }, } diff --git a/management/server/account_test.go b/management/server/account_test.go index b20071cba..19514dad1 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1957,6 +1957,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) { } } +func TestAccount_GetInactivePeers(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + testCases := []test{ + { + name: "Peers with inactivity expiration disabled, no expired peers", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + }, + "peer-2": { + InactivityExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Two peers expired", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-30 * time.Minute), + UserID: userID, + }, + "peer-2": { + ID: "peer-2", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-2 * time.Hour), + UserID: userID, + }, + "peer-3": { + ID: "peer-3", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC(), + 
Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-1 * time.Hour), + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + "peer-2": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + } + + expiredPeers := account.GetInactivePeers() + assert.Len(t, expiredPeers, len(testCase.expectedPeers)) + for _, peer := range expiredPeers { + if _, ok := testCase.expectedPeers[peer.ID]; !ok { + t.Fatalf("expected to have peer %s expired", peer.ID) + } + } + }) + } +} + func TestAccount_GetPeersWithExpiration(t *testing.T) { type test struct { name string @@ -2026,6 +2110,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { } } +func TestAccount_GetPeersWithInactivity(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + + testCases := []test{ + { + name: "No account peers, no peers with expiration", + peers: map[string]*nbpeer.Peer{}, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration disabled, no peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration enabled, return peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := 
&Account{ + Peers: testCase.peers, + } + + actual := account.GetPeersWithInactivity() + assert.Len(t, actual, len(testCase.expectedPeers)) + if len(testCase.expectedPeers) > 0 { + for k := range testCase.expectedPeers { + contains := false + for _, peer := range actual { + if k == peer.ID { + contains = true + } + } + assert.True(t, contains) + } + } + }) + } +} + func TestAccount_GetNextPeerExpiration(t *testing.T) { type test struct { name string @@ -2187,6 +2340,168 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } } +func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expiration time.Duration + expirationEnabled bool + expectedNextRun bool + expectedNextExpiration time.Duration + } + + expectedNextExpiration := time.Minute + testCases := []test{ + { + name: "No peers, no expiration", + peers: map[string]*nbpeer.Peer{}, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "No connected peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Connected peers with disabled expiration, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + 
expectedNextExpiration: time.Duration(0), + }, + { + name: "Expired peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "To be expired peer, return expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + LoginExpired: false, + LastSeen: time.Now().Add(-1 * time.Second), + }, + InactivityExpirationEnabled: true, + LastLogin: time.Now().UTC(), + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Minute, + expirationEnabled: false, + expectedNextRun: true, + expectedNextExpiration: expectedNextExpiration, + }, + { + name: "Peers added with setup keys, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, + } + + expiration, ok := 
account.GetNextInactivePeerExpiration() + assert.Equal(t, testCase.expectedNextRun, ok) + if testCase.expectedNextRun { + assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration) + } else { + assert.Equal(t, expiration, testCase.expectedNextExpiration) + } + }) + } +} + func TestAccount_SetJWTGroups(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4ee57f181..188494241 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -139,6 +139,13 @@ const ( PostureCheckUpdated Activity = 61 // PostureCheckDeleted indicates that the user deleted a posture check PostureCheckDeleted Activity = 62 + + PeerInactivityExpirationEnabled Activity = 63 + PeerInactivityExpirationDisabled Activity = 64 + + AccountPeerInactivityExpirationEnabled Activity = 65 + AccountPeerInactivityExpirationDisabled Activity = 66 + AccountPeerInactivityExpirationDurationUpdated Activity = 67 ) var activityMap = map[Activity]Code{ @@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{ PostureCheckCreated: {"Posture check created", "posture.check.created"}, PostureCheckUpdated: {"Posture check updated", "posture.check.updated"}, PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"}, + + PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"}, + PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"}, + + AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"}, + AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, + AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration 
updated", "account.peer.inactivity.expiration.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/file_store.go b/management/server/file_store.go index df3e9bb77..561e133ce 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) { account.Settings = &Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: DefaultPeerLoginExpiration, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, } } diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 91caa1512..4d4066de4 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), } if req.Settings.Extra != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index fd0343e97..9d5148248 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -54,6 +54,14 @@ components: description: Period of time after which peer login expires (seconds). type: integer example: 43200 + peer_inactivity_expiration_enabled: + description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). 
Applies only to peers that were added by a user (interactive SSO login). + type: boolean + example: true + peer_inactivity_expiration: + description: Period of time of inactivity after which peer session expires (seconds). + type: integer + example: 43200 regular_users_view_blocked: description: Allows blocking regular users from viewing parts of the system. type: boolean @@ -81,6 +89,8 @@ components: required: - peer_login_expiration_enabled - peer_login_expiration + - peer_inactivity_expiration_enabled + - peer_inactivity_expiration - regular_users_view_blocked AccountExtraSettings: type: object @@ -243,6 +253,9 @@ components: login_expiration_enabled: type: boolean example: false + inactivity_expiration_enabled: + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -251,6 +264,7 @@ components: - name - ssh_enabled - login_expiration_enabled + - inactivity_expiration_enabled Peer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -327,6 +341,10 @@ components: type: string format: date-time example: "2023-05-05T09:00:35.477782Z" + inactivity_expiration_enabled: + description: Indicates whether peer inactivity expiration has been enabled or not + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -354,6 +372,7 @@ components: - last_seen - login_expiration_enabled - login_expired + - inactivity_expiration_enabled - os - ssh_enabled - user_id diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 570ec03c5..e2870d5d8 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -220,6 +220,12 @@ type AccountSettings struct { // JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups. 
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"` + // PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds). + PeerInactivityExpiration int `json:"peer_inactivity_expiration"` + + // PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). + PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"` + // PeerLoginExpiration Period of time after which peer login expires (seconds). PeerLoginExpiration int `json:"peer_login_expiration"` @@ -538,6 +544,9 @@ type Peer struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -613,6 +622,9 @@ type PeerBatch struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string // PeerRequest defines model for PeerRequest. 
type PeerRequest struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` - LoginExpirationEnabled bool `json:"login_expiration_enabled"` - Name string `json:"name"` - SshEnabled bool `json:"ssh_enabled"` + ApprovalRequired *bool `json:"approval_required,omitempty"` + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + LoginExpirationEnabled bool `json:"login_expiration_enabled"` + Name string `json:"name"` + SshEnabled bool `json:"ssh_enabled"` } // PersonalAccessToken defines model for PersonalAccessToken. diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 4fbbc3106..a5856a0e4 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -7,6 +7,8 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -14,7 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) // PeersHandler is a handler that returns peers of the account @@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, SSHEnabled: req.SshEnabled, Name: req.Name, LoginExpirationEnabled: req.LoginExpirationEnabled, + + InactivityExpirationEnabled: req.InactivityExpirationEnabled, } if req.ApprovalRequired != nil { @@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD } return &api.Peer{ - Id: peer.ID, - Name: peer.Name, - Ip: peer.IP.String(), - ConnectionIp: peer.Location.ConnectionIP.String(), - Connected: peer.Status.Connected, - LastSeen: 
peer.Status.LastSeen, - Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), - KernelVersion: peer.Meta.KernelVersion, - GeonameId: int(peer.Location.GeoNameID), - Version: peer.Meta.WtVersion, - Groups: groupsInfo, - SshEnabled: peer.SSHEnabled, - Hostname: peer.Meta.Hostname, - UserId: peer.UserID, - UiVersion: peer.Meta.UIVersion, - DnsLabel: fqdn(peer, dnsDomain), - LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.LastLogin, - LoginExpired: peer.Status.LoginExpired, - ApprovalRequired: !approved, - CountryCode: peer.Location.CountryCode, - CityName: peer.Location.CityName, - SerialNumber: peer.Meta.SystemSerialNumber, + Id: peer.ID, + Name: peer.Name, + Ip: peer.IP.String(), + ConnectionIp: peer.Location.ConnectionIP.String(), + Connected: peer.Status.Connected, + LastSeen: peer.Status.LastSeen, + Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), + KernelVersion: peer.Meta.KernelVersion, + GeonameId: int(peer.Location.GeoNameID), + Version: peer.Meta.WtVersion, + Groups: groupsInfo, + SshEnabled: peer.SSHEnabled, + Hostname: peer.Meta.Hostname, + UserId: peer.UserID, + UiVersion: peer.Meta.UIVersion, + DnsLabel: fqdn(peer, dnsDomain), + LoginExpirationEnabled: peer.LoginExpirationEnabled, + LastLogin: peer.LastLogin, + LoginExpired: peer.Status.LoginExpired, + ApprovalRequired: !approved, + CountryCode: peer.Location.CountryCode, + CityName: peer.Location.CityName, + SerialNumber: peer.Meta.SystemSerialNumber, + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } @@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, SerialNumber: peer.Meta.SystemSerialNumber, + + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } diff --git a/management/server/peer.go b/management/server/peer.go index a7d4f3b06..a4c7e1266 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -110,6 
+110,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } + expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) + if err != nil { + return err + } + + if peer.AddedWithSSOLogin() { + if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, account) + } + + if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + } + + if expired { + // we need to update other peers because when peer login expires all other peers are notified to disconnect from + // the expired one. Here we notify them that connection is now allowed again. + am.updateAccountPeers(ctx, account) + } + + return nil +} + +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -138,25 +163,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) if err != nil { - return err + return false, err } - if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) - } - - if oldStatus.LoginExpired { - // we need to update other peers because when peer login expires all other peers are notified to disconnect from - // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account) - } - - return nil + return oldStatus.LoginExpired, nil } -// UpdatePeer updates peer. 
Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. +// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -219,6 +234,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } + if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { + + if !peer.AddedWithSSOLogin() { + return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + } + + peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled + + event := activity.PeerInactivityExpirationEnabled + if !update.InactivityExpirationEnabled { + event = activity.PeerInactivityExpirationDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + + if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + } + account.UpdatePeer(peer) err = am.Store.SaveAccount(ctx, account) @@ -442,23 +476,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s registrationTime := time.Now().UTC() newPeer = &nbpeer.Peer{ - ID: xid.New().String(), - AccountID: accountID, - Key: peer.Key, - SetupKey: upperKey, - IP: freeIP, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: freeLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: registrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - 
Location: peer.Location, + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + SetupKey: upperKey, + IP: freeIP, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + DNSLabel: freeLabel, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: registrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + InactivityExpirationEnabled: addedByUser, } opEvent.TargetID = newPeer.ID opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 3d9ba18e9..9a53459a8 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -38,6 +38,8 @@ type Peer struct { // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin LoginExpirationEnabled bool + + InactivityExpirationEnabled bool // LastLogin the time when peer performed last login operation LastLogin time.Time // CreatedAt records the time the peer was created @@ -187,6 +189,8 @@ func (p *Peer) Copy() *Peer { CreatedAt: p.CreatedAt, Ephemeral: p.Ephemeral, Location: p.Location, + + InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } @@ -219,6 +223,22 @@ func (p *Peer) MarkLoginExpired(expired bool) { p.Status = newStatus } +// SessionExpired indicates whether the peer's session has expired or not. +// If Peer.Status.LastSeen plus the expiresIn duration has happened already; then session has expired. +// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired). +// Session expiration can be disabled/enabled on a Peer level via Peer.InactivityExpirationEnabled property. +// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerInactivityExpirationEnabled.
+// Only peers added by interactive SSO login can be expired. +func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) { + if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected { + return false, 0 + } + expiresAt := p.Status.LastSeen.Add(expiresIn) + now := time.Now() + timeLeft := expiresAt.Sub(now) + return timeLeft <= 0, timeLeft +} + // LoginExpired indicates whether the peer's login has expired or not. // If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. // Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). diff --git a/management/server/peer_test.go b/management/server/peer_test.go index f3bf0ddba..c5edb5636 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) { } } +func TestPeer_SessionExpired(t *testing.T) { + tt := []struct { + name string + expirationEnabled bool + lastLogin time.Time + connected bool + expected bool + accountSettings *Settings + }{ + { + name: "Peer Inactivity Expiration Disabled. 
Peer Inactivity Should Not Expire", + expirationEnabled: false, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Hour, + }, + expected: false, + }, + { + name: "Peer Inactivity Should Expire", + expirationEnabled: true, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: true, + }, + { + name: "Peer Inactivity Should Not Expire", + expirationEnabled: true, + connected: true, + lastLogin: time.Now().UTC(), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: false, + }, + } + + for _, c := range tt { + t.Run(c.name, func(t *testing.T) { + peerStatus := &nbpeer.PeerStatus{ + Connected: c.connected, + } + peer := &nbpeer.Peer{ + InactivityExpirationEnabled: c.expirationEnabled, + LastLogin: c.lastLogin, + Status: peerStatus, + UserID: userID, + } + + expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration) + assert.Equal(t, expired, c.expected) + }) + } +} + func TestAccountManager_GetNetworkMap(t *testing.T) { manager, err := createManager(t) if err != nil {