diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 2aaef7564..88db8c5e8 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -42,4 +42,4 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index 4f13ee30e..a2d743715 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -38,7 +38,7 @@ jobs: time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./formatter/... - time go test -timeout 1m -failfast ./iface/... + time go test -timeout 1m -failfast ./client/iface/... time go test -timeout 1m -failfast ./route/... time go test -timeout 1m -failfast ./sharedsock/... time go test -timeout 1m -failfast ./signal/... diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 2d5cf2856..e1e1ff236 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -16,7 +16,7 @@ jobs: matrix: arch: [ '386','amd64' ] store: [ 'sqlite', 'postgres'] - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Install Go uses: actions/setup-go@v5 @@ -49,7 +49,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... test_client_on_docker: runs-on: ubuntu-20.04 @@ -80,7 +80,7 @@ jobs: run: git --no-pager diff --exit-code - name: Generate Iface Test bin - run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/ + run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/ - name: Generate Shared Sock Test bin run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 8b7136841..2d743f790 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable, + ignore_words_list: erro,clienta,hastable,iif skip: go.mod,go.sum only_warn: 1 golangci: diff --git a/.github/workflows/install-script-test.yml b/.github/workflows/install-script-test.yml index 04c222e87..22d002a48 100644 --- a/.github/workflows/install-script-test.yml +++ b/.github/workflows/install-script-test.yml @@ -13,6 +13,7 @@ concurrency: jobs: test-install-script: strategy: + fail-fast: false max-parallel: 2 matrix: os: [ubuntu-latest, macos-latest] diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5f423f1c9..b2e2437e6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,15 +3,14 @@ name: Release on: push: tags: - - 'v*' + - "v*" branches: - main pull_request: - env: SIGN_PIPE_VER: "v0.0.14" - GORELEASER_VER: "v1.14.1" + GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" @@ -21,7 +20,7 @@ concurrency: jobs: release: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: flags: "" steps: @@ -34,19 +33,16 @@ jobs: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - - - name: Checkout + - name: Checkout uses: actions/checkout@v4 with: fetch-depth: 0 # It is required for GoReleaser to work properly - - - name: Set up Go + - name: Set up Go uses: actions/setup-go@v5 with: go-version: "1.23" cache: false - - - name: Cache Go modules + - name: Cache Go modules uses: actions/cache@v4 with: path: | @@ -55,24 +51,19 @@ jobs: key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go-releaser- - - - name: Install modules + - name: Install modules run: go mod tidy - - - name: check git status + - name: check git status run: git --no-pager diff --exit-code - - - name: Set up QEMU + - name: Set up QEMU uses: docker/setup-qemu-action@v2 - - - name: Set up Docker Buildx + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - - - name: Login to Docker hub + - name: Login to Docker hub if: github.event_name != 'pull_request' uses: docker/login-action@v1 with: - username: netbirdio + username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_TOKEN }} - name: Install OS build dependencies run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu @@ -85,35 +76,31 @@ jobs: uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --rm-dist ${{ env.flags }} + args: release --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} - - - name: upload non tags for debug purposes + - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 with: name: release path: dist/ retention-days: 3 - - - name: upload linux packages + - name: upload linux packages uses: actions/upload-artifact@v4 with: name: linux-packages path: dist/netbird_linux** retention-days: 3 - - - name: upload windows packages + - name: upload windows packages uses: actions/upload-artifact@v4 with: name: windows-packages path: dist/netbird_windows** retention-days: 3 - - - name: upload macos packages + - name: upload macos packages uses: actions/upload-artifact@v4 with: name: macos-packages @@ -145,7 +132,7 @@ jobs: - name: Cache Go modules uses: actions/cache@v4 with: - path: | + path: | ~/go/pkg/mod ~/.cache/go-build key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }} @@ -169,7 +156,7 @@ jobs: uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }} + args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} @@ -187,19 +174,16 @@ jobs: steps: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - - - name: Checkout + - name: Checkout uses: actions/checkout@v4 with: fetch-depth: 0 # It is required for GoReleaser to work properly - - - name: Set up Go + - name: Set up Go uses: actions/setup-go@v5 with: go-version: "1.23" cache: false - - - name: Cache Go modules + - name: Cache Go modules uses: actions/cache@v4 with: path: | @@ -208,23 +192,19 @@ jobs: key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-ui-go-releaser-darwin- - - - name: Install modules + - name: Install modules run: go mod tidy - - - name: check git status + - name: check git status run: git --no-pager diff --exit-code - - - name: Run GoReleaser + - name: Run GoReleaser id: goreleaser uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }} + args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: upload non tags for debug purposes + - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 with: name: release-ui-darwin @@ -233,7 +213,7 @@ jobs: trigger_signer: runs-on: ubuntu-latest - needs: [release,release_ui,release_ui_darwin] + needs: [release, release_ui, release_ui_darwin] if: startsWith(github.ref, 'refs/tags/') steps: - name: Trigger binaries sign pipelines diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 068864d6e..cf2ce4f4f 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird builds: - id: netbird @@ -22,7 +24,7 @@ builds: goarch: 386 ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc @@ -42,19 +44,19 @@ builds: - softfloat ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc - id: netbird-mgmt dir: management env: - - CGO_ENABLED=1 - - >- - {{- if eq .Runtime.Goos "linux" }} - {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} - {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} - {{- end }} + - CGO_ENABLED=1 + - >- + {{- if eq .Runtime.Goos "linux" }} + {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} + {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} + {{- end }} binary: netbird-mgmt goos: - linux @@ -64,7 +66,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-signal dir: signal @@ -78,7 +80,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-relay dir: relay @@ -92,7 +94,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" archives: - builds: @@ -100,7 +102,6 @@ archives: - netbird-static nfpms: - - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ @@ -416,10 +417,9 @@ docker_manifests: - netbirdio/management:{{ .Version }}-debug-amd64 brews: - - - ids: + - ids: - default - tap: + repository: owner: netbirdio name: homebrew-tap token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}" @@ -436,7 +436,7 @@ brews: uploads: - name: debian ids: - - netbird-deb + - netbird-deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index fd92b5328..06577f4e3 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird-ui builds: - id: netbird-ui @@ -11,7 +13,7 @@ builds: - amd64 ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-ui-windows dir: client/ui @@ -26,7 +28,7 @@ builds: ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -H windowsgui - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" archives: - id: linux-arch @@ -39,7 +41,6 @@ archives: - netbird-ui-windows nfpms: - - maintainer: Netbird description: Netbird client UI. homepage: https://netbird.io/ @@ -77,7 +78,7 @@ nfpms: uploads: - name: debian ids: - - netbird-ui-deb + - netbird-ui-deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml index 2c3afa91b..bccb7f471 100644 --- a/.goreleaser_ui_darwin.yaml +++ b/.goreleaser_ui_darwin.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird-ui builds: - id: netbird-ui-darwin @@ -17,7 +19,7 @@ builds: - softfloat ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc @@ -28,4 +30,4 @@ archives: checksum: name_template: "{{ .ProjectName }}_darwin_checksums.txt" changelog: - skip: true \ No newline at end of file + disable: true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 492aa5c2e..c82cfc763 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR: **Goreleaser** ```shell -goreleaser --snapshot --rm-dist +goreleaser build --snapshot --clean ``` **golangci-lint** ```shell diff --git a/README.md b/README.md index aa3ec41e5..270c9ad87 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) +### NetBird on Lawrence Systems (Video) +[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) ### Key features @@ -62,6 +64,7 @@ | | | | | | | | |
  • - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)
  • | |
    • - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)
    | | | | | |
    • - \[x] Docker
    | + ### Quickstart with NetBird Cloud - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) diff --git a/client/android/client.go b/client/android/client.go index d937e132e..229bcd974 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" @@ -15,7 +16,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util/net" ) @@ -26,7 +26,7 @@ type ConnectionListener interface { // TunAdapter export internal TunAdapter for mobile type TunAdapter interface { - iface.TunAdapter + device.TunAdapter } // IFaceDiscover export internal IFaceDiscover for mobile @@ -51,7 +51,7 @@ func init() { // Client struct manage the life circle of background service type Client struct { cfgFile string - tunAdapter iface.TunAdapter + tunAdapter device.TunAdapter iFaceDiscover IFaceDiscover recorder *peer.Status ctxCancel context.CancelFunc diff --git a/client/cmd/login_test.go b/client/cmd/login_test.go index 6bb7eff4f..fa20435ea 100644 --- a/client/cmd/login_test.go +++ b/client/cmd/login_test.go @@ -5,8 +5,8 @@ import ( "strings" "testing" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/cmd/root_test.go b/client/cmd/root_test.go index f2805cf35..4cbbe8783 100644 --- a/client/cmd/root_test.go +++ b/client/cmd/root_test.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) func TestInitCommands(t *testing.T) { diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 780cc8b04..d998f9ea9 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -3,7 +3,6 @@ package cmd import ( "context" "net" - "path/filepath" "testing" "time" @@ -34,18 +33,12 @@ func startTestingServices(t *testing.T) string { if err != nil { t.Fatal(err) } - testDir := t.TempDir() - config.Datadir = testDir - err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json")) - if err != nil { - t.Fatal(err) - } _, signalLis := startSignal(t) signalAddr := signalLis.Addr().String() config.Signal.URI = signalAddr - _, mgmLis := startManagement(t, config) + _, mgmLis := startManagement(t, config, "../testdata/store.sql") mgmAddr := mgmLis.Addr().String() return mgmAddr } @@ -57,7 +50,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - srv, err := sig.NewServer(otel.Meter("")) + srv, err := sig.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) sigProto.RegisterSignalExchangeServer(s, srv) @@ -70,7 +63,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { return s, lis } -func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) { +func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { t.Helper() lis, err := net.Listen("tcp", ":0") @@ -78,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index b447f7141..05ecce9e0 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -15,11 +15,11 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/firewall/iface.go b/client/firewall/iface.go index 882daef75..f349f9210 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -1,11 +1,13 @@ package firewall -import "github.com/netbirdio/netbird/iface" +import ( + "github.com/netbirdio/netbird/client/iface/device" +) // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { Name() string - Address() iface.WGAddress + Address() device.WGAddress IsUserspaceBind() bool - SetFilter(iface.PacketFilter) error + SetFilter(device.PacketFilter) error } diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index b77cc8f43..c271e592d 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -19,27 +20,32 @@ const ( // rules chains contains the effective ACL rules chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT" - - postRoutingMark = "0x000007e4" ) -type aclManager struct { - iptablesClient *iptables.IPTables - wgIface iFaceMapper - routeingFwChainName string - - entries map[string][][]string - ipsetStore *ipsetStore +type entry struct { + spec []string + position int } -func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) { - m := &aclManager{ - iptablesClient: iptablesClient, - wgIface: wgIface, - routeingFwChainName: routeingFwChainName, +type aclManager struct { + iptablesClient *iptables.IPTables + wgIface iFaceMapper + routingFwChainName string - entries: make(map[string][][]string), - ipsetStore: newIpsetStore(), + entries map[string][][]string + optionalEntries map[string][]entry + ipsetStore *ipsetStore +} + +func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { + m := &aclManager{ + iptablesClient: iptablesClient, + wgIface: wgIface, + routingFwChainName: routingFwChainName, + + entries: make(map[string][][]string), + optionalEntries: make(map[string][]entry), + ipsetStore: newIpsetStore(), } err := ipset.Init() @@ -48,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, route } m.seedInitialEntries() + m.seedInitialOptionalEntries() err = m.cleanChains() if err != nil { @@ -61,7 +68,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, route return m, nil } -func (m *aclManager) AddFiltering( +func (m *aclManager) AddPeerFiltering( ip net.IP, protocol firewall.Protocol, sPort *firewall.Port, @@ -127,7 +134,7 @@ func (m *aclManager) AddFiltering( return nil, fmt.Errorf("rule already exists") } - if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil { + if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { return nil, err } @@ -139,28 +146,16 @@ func (m *aclManager) AddFiltering( chain: chain, } - if !shouldAddToPrerouting(protocol, dPort, direction) { - return []firewall.Rule{rule}, nil - } - - rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip) - if err != nil { - return []firewall.Rule{rule}, err - } - return []firewall.Rule{rule, rulePrerouting}, nil + return []firewall.Rule{rule}, nil } -// DeleteRule from the firewall by rule definition -func (m *aclManager) DeleteRule(rule firewall.Rule) error { +// DeletePeerRule from the firewall by rule definition +func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") } - if r.chain == "PREROUTING" { - goto DELETERULE - } - if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { // delete IP from ruleset IPs list and ipset if _, ok := ipsetList.ips[r.ip]; ok { @@ -185,14 +180,7 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error { } } -DELETERULE: - var table string - if r.chain == "PREROUTING" { - table = "mangle" - } else { - table = "filter" - } - err := m.iptablesClient.Delete(table, r.chain, r.specs...) + err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) if err != nil { log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err) } @@ -203,44 +191,6 @@ func (m *aclManager) Reset() error { return m.cleanChains() } -func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) { - var src []string - if ipsetName != "" { - src = []string{"-m", "set", "--set", ipsetName, "src"} - } else { - src = []string{"-s", ip.String()} - } - specs := []string{ - "-d", m.wgIface.Address().IP.String(), - "-p", protocol, - "--dport", port, - "-j", "MARK", "--set-mark", postRoutingMark, - } - - specs = append(src, specs...) - - ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...) - if err != nil { - return nil, fmt.Errorf("failed to check rule: %w", err) - } - if ok { - return nil, fmt.Errorf("rule already exists") - } - - if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil { - return nil, err - } - - rule := &Rule{ - ruleID: uuid.New().String(), - specs: specs, - ipsetName: ipsetName, - ip: ip.String(), - chain: "PREROUTING", - } - return rule, nil -} - // todo write less destructive cleanup mechanism func (m *aclManager) cleanChains() error { ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) @@ -293,8 +243,7 @@ func (m *aclManager) cleanChains() error { ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") if err != nil { - log.Debugf("failed to list chains: %s", err) - return err + return fmt.Errorf("list chains: %w", err) } if ok { for _, rule := range m.entries["PREROUTING"] { @@ -303,11 +252,6 @@ func (m *aclManager) cleanChains() error { log.Errorf("failed to delete rule: %v, %s", rule, err) } } - err = m.iptablesClient.ClearChain("mangle", "PREROUTING") - if err != nil { - log.Debugf("failed to clear %s chain: %s", "PREROUTING", err) - return err - } } for _, ipsetName := range m.ipsetStore.ipsetNames() { @@ -338,58 +282,66 @@ func (m *aclManager) createDefaultChains() error { for chainName, rules := range m.entries { for _, rule := range rules { - if chainName == "FORWARD" { - // position 2 because we add it after router's, jump rule - if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil { - log.Debugf("failed to create input chain jump rule: %s", err) - return err - } - } else { - if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil { - log.Debugf("failed to create input chain jump rule: %s", err) - return err - } + if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { + log.Debugf("failed to create input chain jump rule: %s", err) + return err } } } + for chainName, entries := range m.optionalEntries { + for _, entry := range entries { + if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil { + log.Errorf("failed to insert optional entry %v: %v", entry.spec, err) + continue + } + m.entries[chainName] = append(m.entries[chainName], entry.spec) + } + } + clear(m.optionalEntries) + return nil } +// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed. +// We want to make sure our traffic is not dropped by existing rules. + +// The existing FORWARD rules/policies decide outbound traffic towards our interface. +// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. + +// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. func (m *aclManager) seedInitialEntries() { - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules}) + established := getConntrackEstablished() m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules}) + m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) + m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) + m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules}) + m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) + m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) - m.appendToEntries("FORWARD", - []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) - m.appendToEntries("FORWARD", - []string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) - m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName}) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName}) + m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) + m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) +} - m.appendToEntries("PREROUTING", - []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark}) +func (m *aclManager) seedInitialOptionalEntries() { + m.optionalEntries["FORWARD"] = []entry{ + { + spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, + position: 2, + }, + } + + m.optionalEntries["PREROUTING"] = []entry{ + { + spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, + position: 1, + }, + } } func (m *aclManager) appendToEntries(chainName string, spec []string) { @@ -456,18 +408,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string { return ipsetName } } - -func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { - if proto == "all" { - return false - } - - if direction != firewall.RuleDirectionIN { - return false - } - - if dPort == nil { - return false - } - return true -} diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 2d231ec45..94bd2fccf 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -4,13 +4,14 @@ import ( "context" "fmt" "net" + "net/netip" "sync" "github.com/coreos/go-iptables/iptables" log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) // Manager of iptables firewall @@ -21,7 +22,7 @@ type Manager struct { ipv4Client *iptables.IPTables aclMgr *aclManager - router *routerManager + router *router } // iFaceMapper defines subset methods of interface required for manager @@ -43,12 +44,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouterManager(context, iptablesClient) + m.router, err = newRouter(context, iptablesClient, wgIface) if err != nil { log.Debugf("failed to initialize route related chains: %s", err) return nil, err } - m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName()) + m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) if err != nil { log.Debugf("failed to initialize ACL manager: %s", err) return nil, err @@ -57,10 +58,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, protocol firewall.Protocol, sPort *firewall.Port, @@ -73,33 +74,62 @@ func (m *Manager) AddFiltering( m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) + return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.DeleteRule(rule) + if !destination.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + } + + return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.aclMgr.DeletePeerRule(rule) +} + +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteRouteRule(rule) } func (m *Manager) IsServerRouteSupported() bool { return true } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.InsertRoutingRules(pair) + return m.router.AddNatRule(pair) } -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveRoutingRules(pair) + return m.router.RemoveNatRule(pair) +} + +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + return firewall.SetLegacyManagement(m.router, isLegacy) } // Reset firewall to the default state @@ -125,7 +155,7 @@ func (m *Manager) AllowNetbird() error { return nil } - _, err := m.AddFiltering( + _, err := m.AddPeerFiltering( net.ParseIP("0.0.0.0"), "all", nil, @@ -138,7 +168,7 @@ func (m *Manager) AllowNetbird() error { if err != nil { return fmt.Errorf("failed to allow netbird interface traffic: %w", err) } - _, err = m.AddFiltering( + _, err = m.AddPeerFiltering( net.ParseIP("0.0.0.0"), "all", nil, @@ -153,3 +183,7 @@ func (m *Manager) AllowNetbird() error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } + +func getConntrackEstablished() []string { + return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} +} diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index ceb116c62..498d8f58b 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -11,9 +11,24 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) +var ifaceMock = &iFaceMock{ + NameFunc: func() string { + return "lo" + }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("10.20.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("10.20.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string @@ -40,23 +55,8 @@ func TestIptablesManager(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) - mock := &iFaceMock{ - NameFunc: func() string { - return "lo" - }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ - IP: net.ParseIP("10.20.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("10.20.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - } - }, - } - // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) time.Sleep(time.Second) @@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) { t.Run("add first rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") for _, r := range rule1 { @@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) { port := &fw.Port{ Values: []int{8043: 8046}, } - rule2, err = manager.AddFiltering( + rule2, err = manager.AddPeerFiltering( ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") require.NoError(t, err, "failed to add rule") @@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) { t.Run("delete first rule", func(t *testing.T) { for _, r := range rule1 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) @@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) { t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") } @@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) { // add second rule ip := net.ParseIP("10.20.0.3") port := &fw.Port{Values: []int{5353}} - _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") + _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") err = manager.Reset() @@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("add first rule with set", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddFiltering( + rule1, err = manager.AddPeerFiltering( ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "default", "accept HTTP traffic", ) @@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) { port := &fw.Port{ Values: []int{443}, } - rule2, err = manager.AddFiltering( + rule2, err = manager.AddPeerFiltering( ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "default", "accept HTTPS traffic from ports range", ) @@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("delete first rule", func(t *testing.T) { for _, r := range rule1 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") @@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") @@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index e8f09a106..e60c352d5 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -5,368 +5,479 @@ package iptables import ( "context" "fmt" + "net/netip" + "strconv" "strings" "github.com/coreos/go-iptables/iptables" + "github.com/hashicorp/go-multierror" + "github.com/nadoo/ipset" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) const ( - Ipv4Forwarding = "netbird-rt-forwarding" - ipv4Nat = "netbird-rt-nat" + ipv4Nat = "netbird-rt-nat" ) // constants needed to manage and create iptable rules const ( tableFilter = "filter" tableNat = "nat" - chainFORWARD = "FORWARD" chainPOSTROUTING = "POSTROUTING" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWD = "NETBIRD-RT-FWD" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" + + matchSet = "--match-set" ) -type routerManager struct { - ctx context.Context - stop context.CancelFunc - iptablesClient *iptables.IPTables - rules map[string][]string +type routeFilteringRuleParams struct { + Sources []netip.Prefix + Destination netip.Prefix + Proto firewall.Protocol + SPort *firewall.Port + DPort *firewall.Port + Direction firewall.RuleDirection + Action firewall.Action + SetName string } -func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) { +type router struct { + ctx context.Context + stop context.CancelFunc + iptablesClient *iptables.IPTables + rules map[string][]string + ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}] + wgIface iFaceMapper + legacyManagement bool +} + +func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { ctx, cancel := context.WithCancel(parentCtx) - m := &routerManager{ + r := &router{ ctx: ctx, stop: cancel, iptablesClient: iptablesClient, rules: make(map[string][]string), + wgIface: wgIface, } - err := m.cleanUpDefaultForwardRules() + r.ipsetCounter = refcounter.New( + r.createIpSet, + func(name string, _ struct{}) error { + return r.deleteIpSet(name) + }, + ) + + if err := ipset.Init(); err != nil { + return nil, fmt.Errorf("init ipset: %w", err) + } + + err := r.cleanUpDefaultForwardRules() if err != nil { - log.Errorf("failed to cleanup routing rules: %s", err) + log.Errorf("cleanup routing rules: %s", err) return nil, err } - err = m.createContainers() + err = r.createContainers() if err != nil { - log.Errorf("failed to create containers for route: %s", err) + log.Errorf("create containers for route: %s", err) } - return m, err + return r, err } -// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain -func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error { - err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair) - if err != nil { - return err +func (r *router) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) + if _, ok := r.rules[string(ruleKey)]; ok { + return ruleKey, nil } - err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair)) - if err != nil { - return err + var setName string + if len(sources) > 1 { + setName = firewall.GenerateSetName(sources) + if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { + return nil, fmt.Errorf("create or get ipset: %w", err) + } + } + + params := routeFilteringRuleParams{ + Sources: sources, + Destination: destination, + Proto: proto, + SPort: sPort, + DPort: dPort, + Action: action, + SetName: setName, + } + + rule := genRouteFilteringRuleSpec(params) + if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + return nil, fmt.Errorf("add route rule: %v", err) + } + + r.rules[string(ruleKey)] = rule + + return ruleKey, nil +} + +func (r *router) DeleteRouteRule(rule firewall.Rule) error { + ruleKey := rule.GetRuleID() + + if rule, exists := r.rules[ruleKey]; exists { + setName := r.findSetNameInRule(rule) + + if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("delete route rule: %v", err) + } + delete(r.rules, ruleKey) + + if setName != "" { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + return fmt.Errorf("failed to remove ipset: %w", err) + } + } + } else { + log.Debugf("route rule %s not found", ruleKey) + } + + return nil +} + +func (r *router) findSetNameInRule(rule []string) string { + for i, arg := range rule { + if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { + return rule[i+3] + } + } + return "" +} + +func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) { + if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { + return struct{}{}, fmt.Errorf("create set %s: %w", setName, err) + } + + for _, prefix := range sources { + if err := ipset.AddPrefix(setName, prefix); err != nil { + return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err) + } + } + + return struct{}{}, nil +} + +func (r *router) deleteIpSet(setName string) error { + if err := ipset.Destroy(setName); err != nil { + return fmt.Errorf("destroy set %s: %w", setName, err) + } + return nil +} + +// AddNatRule inserts an iptables rule pair into the nat chain +func (r *router) AddNatRule(pair firewall.RouterPair) error { + if r.legacyManagement { + log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) + if err := r.addLegacyRouteRule(pair); err != nil { + return fmt.Errorf("add legacy routing rule: %w", err) + } } if !pair.Masquerade { return nil } - err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) - if err != nil { - return err + if err := r.addNatRule(pair); err != nil { + return fmt.Errorf("add nat rule: %w", err) } - err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) - if err != nil { - return err + if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("add inverse nat rule: %w", err) } return nil } -// insertRoutingRule inserts an iptables rule -func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { - var err error +// RemoveNatRule removes an iptables rule pair from forwarding and nat chains +func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat rule: %w", err) + } - ruleKey := firewall.GenKey(keyFormat, pair.ID) - rule := genRuleSpec(jump, pair.Source, pair.Destination) - existingRule, found := i.rules[ruleKey] - if found { - err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } + + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + + return nil +} + +// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls +func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if err := r.removeLegacyRouteRule(pair); err != nil { + return err + } + + rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} + if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + r.rules[ruleKey] = rule + + return nil +} + +func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } - delete(i.rules, ruleKey) - } - - err = i.iptablesClient.Insert(table, chain, 1, rule...) - if err != nil { - return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) - } - - i.rules[ruleKey] = rule - - return nil -} - -// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains -func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error { - err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair) - if err != nil { - return err - } - - err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair)) - if err != nil { - return err - } - - if !pair.Masquerade { - return nil - } - - err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair) - if err != nil { - return err - } - - err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair)) - if err != nil { - return err + delete(r.rules, ruleKey) + } else { + log.Debugf("legacy forwarding rule %s not found", ruleKey) } return nil } -func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error { - var err error +// GetLegacyManagement returns the current legacy management mode +func (r *router) GetLegacyManagement() bool { + return r.legacyManagement +} - ruleKey := firewall.GenKey(keyFormat, pair.ID) - existingRule, found := i.rules[ruleKey] - if found { - err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) +// SetLegacyManagement sets the route manager to use legacy management mode +func (r *router) SetLegacyManagement(isLegacy bool) { + r.legacyManagement = isLegacy +} + +// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls +func (r *router) RemoveAllLegacyRouteRules() error { + var merr *multierror.Error + for k, rule := range r.rules { + if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + continue + } + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) } } - delete(i.rules, ruleKey) - - return nil + return nberrors.FormatErrorOrNil(merr) } -func (i *routerManager) RouteingFwChainName() string { - return chainRTFWD -} - -func (i *routerManager) Reset() error { - err := i.cleanUpDefaultForwardRules() - if err != nil { - return err +func (r *router) Reset() error { + var merr *multierror.Error + if err := r.cleanUpDefaultForwardRules(); err != nil { + merr = multierror.Append(merr, err) } - i.rules = make(map[string][]string) - return nil + r.rules = make(map[string][]string) + + if err := r.ipsetCounter.Flush(); err != nil { + merr = multierror.Append(merr, err) + } + + return nberrors.FormatErrorOrNil(merr) } -func (i *routerManager) cleanUpDefaultForwardRules() error { - err := i.cleanJumpRules() +func (r *router) cleanUpDefaultForwardRules() error { + err := r.cleanJumpRules() if err != nil { return err } log.Debug("flushing routing related tables") - ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD) - if err != nil { - log.Errorf("failed check chain %s,error: %v", chainRTFWD, err) - return err - } else if ok { - err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD) + for _, chain := range []string{chainRTFWD, chainRTNAT} { + table := r.getTableForChain(chain) + + ok, err := r.iptablesClient.ChainExists(table, chain) if err != nil { - log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err) + log.Errorf("failed check chain %s, error: %v", chain, err) return err + } else if ok { + err = r.iptablesClient.ClearAndDeleteChain(table, chain) + if err != nil { + log.Errorf("failed cleaning chain %s, error: %v", chain, err) + return err + } } } - ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT) - if err != nil { - log.Errorf("failed check chain %s,error: %v", chainRTNAT, err) - return err - } else if ok { - err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT) - if err != nil { - log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err) - return err - } - } - return nil -} - -func (i *routerManager) createContainers() error { - if i.rules[Ipv4Forwarding] != nil { - return nil - } - - errMSGFormat := "failed creating chain %s,error: %v" - err := i.createChain(tableFilter, chainRTFWD) - if err != nil { - return fmt.Errorf(errMSGFormat, chainRTFWD, err) - } - - err = i.createChain(tableNat, chainRTNAT) - if err != nil { - return fmt.Errorf(errMSGFormat, chainRTNAT, err) - } - - err = i.addJumpRules() - if err != nil { - return fmt.Errorf("error while creating jump rules: %v", err) - } - return nil } -// addJumpRules create jump rules to send packets to NetBird chains -func (i *routerManager) addJumpRules() error { - rule := []string{"-j", chainRTFWD} - err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...) +func (r *router) createContainers() error { + for _, chain := range []string{chainRTFWD, chainRTNAT} { + if err := r.createAndSetupChain(chain); err != nil { + return fmt.Errorf("create chain %s: %w", chain, err) + } + } + + if err := r.insertEstablishedRule(chainRTFWD); err != nil { + return fmt.Errorf("insert established rule: %w", err) + } + + if err := r.addJumpRules(); err != nil { + return fmt.Errorf("add jump rules: %w", err) + } + + return nil +} + +func (r *router) createAndSetupChain(chain string) error { + table := r.getTableForChain(chain) + + if err := r.iptablesClient.NewChain(table, chain); err != nil { + return fmt.Errorf("failed creating chain %s, error: %v", chain, err) + } + + return nil +} + +func (r *router) getTableForChain(chain string) string { + if chain == chainRTNAT { + return tableNat + } + return tableFilter +} + +func (r *router) insertEstablishedRule(chain string) error { + establishedRule := getConntrackEstablished() + + err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...) + if err != nil { + return fmt.Errorf("failed to insert established rule: %v", err) + } + + ruleKey := "established-" + chain + r.rules[ruleKey] = establishedRule + + return nil +} + +func (r *router) addJumpRules() error { + rule := []string{"-j", chainRTNAT} + err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) if err != nil { return err } - i.rules[Ipv4Forwarding] = rule - - rule = []string{"-j", chainRTNAT} - err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) - if err != nil { - return err - } - i.rules[ipv4Nat] = rule + r.rules[ipv4Nat] = rule return nil } -// cleanJumpRules cleans jump rules that was sending packets to NetBird chains -func (i *routerManager) cleanJumpRules() error { - var err error - errMSGFormat := "failed cleaning rule from chain %s,err: %v" - rule, found := i.rules[Ipv4Forwarding] +func (r *router) cleanJumpRules() error { + rule, found := r.rules[ipv4Nat] if found { - err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...) + err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) if err != nil { - return fmt.Errorf(errMSGFormat, chainFORWARD, err) - } - } - rule, found = i.rules[ipv4Nat] - if found { - err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) - if err != nil { - return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err) + return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) } } - rules, err := i.iptablesClient.List("nat", "POSTROUTING") - if err != nil { - return fmt.Errorf("failed to list rules: %s", err) - } - - for _, ruleString := range rules { - if !strings.Contains(ruleString, "NETBIRD") { - continue - } - rule := strings.Fields(ruleString) - err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...) - if err != nil { - return fmt.Errorf("failed to delete postrouting jump rule: %s", err) - } - } - - rules, err = i.iptablesClient.List(tableFilter, "FORWARD") - if err != nil { - return fmt.Errorf("failed to list rules in FORWARD chain: %s", err) - } - - for _, ruleString := range rules { - if !strings.Contains(ruleString, "NETBIRD") { - continue - } - rule := strings.Fields(ruleString) - err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...) - if err != nil { - return fmt.Errorf("failed to delete FORWARD jump rule: %s", err) - } - } return nil } -func (i *routerManager) createChain(table, newChain string) error { - chains, err := i.iptablesClient.ListChains(table) - if err != nil { - return fmt.Errorf("couldn't get %s table chains, error: %v", table, err) - } +func (r *router) addNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) - shouldCreateChain := true - for _, chain := range chains { - if chain == newChain { - shouldCreateChain = false - } - } - - if shouldCreateChain { - err = i.iptablesClient.NewChain(table, newChain) - if err != nil { - return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) - } - - // Add the loopback return rule to the NAT chain - loopbackRule := []string{"-o", "lo", "-j", "RETURN"} - err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...) - if err != nil { - return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err) - } - - err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") - if err != nil { - return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) - } - - } - return nil -} - -// addNATRule appends an iptables rule pair to the nat chain -func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(keyFormat, pair.ID) - rule := genRuleSpec(jump, pair.Source, pair.Destination) - existingRule, found := i.rules[ruleKey] - if found { - err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) } - delete(i.rules, ruleKey) + delete(r.rules, ruleKey) } - // inserting after loopback ignore rule - err := i.iptablesClient.Insert(table, chain, 2, rule...) - if err != nil { + rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) + if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) } - i.rules[ruleKey] = rule + r.rules[ruleKey] = rule return nil } -// genRuleSpec generates rule specification -func genRuleSpec(jump, source, destination string) []string { - return []string{"-s", source, "-d", destination, "-j", jump} +func (r *router) removeNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { + return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err) + } + + delete(r.rules, ruleKey) + } else { + log.Debugf("nat rule %s not found", ruleKey) + } + + return nil } -func getIptablesRuleType(table string) string { - ruleType := "forwarding" - if table == tableNat { - ruleType = "nat" +func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { + intdir := "-i" + if inverse { + intdir = "-o" } - return ruleType + return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump} +} + +func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { + var rule []string + + if params.SetName != "" { + rule = append(rule, "-m", "set", matchSet, params.SetName, "src") + } else if len(params.Sources) > 0 { + source := params.Sources[0] + rule = append(rule, "-s", source.String()) + } + + rule = append(rule, "-d", params.Destination.String()) + + if params.Proto != firewall.ProtocolALL { + rule = append(rule, "-p", strings.ToLower(string(params.Proto))) + rule = append(rule, applyPort("--sport", params.SPort)...) + rule = append(rule, applyPort("--dport", params.DPort)...) + } + + rule = append(rule, "-j", actionToStr(params.Action)) + + return rule +} + +func applyPort(flag string, port *firewall.Port) []string { + if port == nil { + return nil + } + + if port.IsRange && len(port.Values) == 2 { + return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])} + } + + if len(port.Values) > 1 { + portList := make([]string, len(port.Values)) + for i, p := range port.Values { + portList[i] = strconv.Itoa(p) + } + return []string{"-m", "multiport", flag, strings.Join(portList, ",")} + } + + return []string{flag, strconv.Itoa(port.Values[0])} } diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 79b970c36..6cede09e2 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -4,11 +4,13 @@ package iptables import ( "context" + "net/netip" "os/exec" "testing" "github.com/coreos/go-iptables/iptables" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -28,7 +30,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "should return a valid iptables manager") defer func() { @@ -37,26 +39,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.Len(t, manager.rules, 2, "should have created rules map") - exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD) - require.True(t, exists, "forwarding rule should exist") - - exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) + exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.True(t, exists, "postrouting rule should exist") pair := firewall.RouterPair{ ID: "abc", - Source: "100.100.100.1/32", - Destination: "100.100.100.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.100.0/24"), Masquerade: true, } - forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination) + forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) require.NoError(t, err, "inserting rule should not return error") - nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination) + nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false) err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) require.NoError(t, err, "inserting rule should not return error") @@ -65,7 +63,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.NoError(t, err, "shouldn't return error") } -func TestIptablesManager_InsertRoutingRules(t *testing.T) { +func TestIptablesManager_AddNatRule(t *testing.T) { if !isIptablesSupported() { t.SkipNow() @@ -76,7 +74,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") defer func() { @@ -86,35 +84,13 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { } }() - err = manager.InsertRoutingRules(testCase.InputPair) + err = manager.AddNatRule(testCase.InputPair) require.NoError(t, err, "forwarding pair should be inserted") - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) - exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.True(t, exists, "forwarding rule should exist") - - foundRule, found := manager.rules[forwardRuleKey] - require.True(t, found, "forwarding rule should exist in the manager map") - require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") - - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) - - exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.True(t, exists, "income forwarding rule should exist") - - foundRule, found = manager.rules[inForwardRuleKey] - require.True(t, found, "income forwarding rule should exist in the manager map") - require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) if testCase.InputPair.Masquerade { require.True(t, exists, "nat rule should be created") @@ -127,8 +103,8 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { require.False(t, foundNat, "nat rule should not exist in the map") } - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) @@ -146,7 +122,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { } } -func TestIptablesManager_RemoveRoutingRules(t *testing.T) { +func TestIptablesManager_RemoveNatRule(t *testing.T) { if !isIptablesSupported() { t.SkipNow() @@ -156,7 +132,7 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") defer func() { _ = manager.Reset() @@ -164,26 +140,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { require.NoError(t, err, "shouldn't return error") - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) - - err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...) - require.NoError(t, err, "inserting rule should not return error") - - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) - - err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...) - require.NoError(t, err, "inserting rule should not return error") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) require.NoError(t, err, "inserting rule should not return error") - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) require.NoError(t, err, "inserting rule should not return error") @@ -191,28 +155,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { err = manager.Reset() require.NoError(t, err, "shouldn't return error") - err = manager.RemoveRoutingRules(testCase.InputPair) + err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") - exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.False(t, exists, "forwarding rule should not exist") - - _, found := manager.rules[forwardRuleKey] - require.False(t, found, "forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.False(t, exists, "income forwarding rule should not exist") - - _, found = manager.rules[inForwardRuleKey] - require.False(t, found, "income forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.False(t, exists, "nat rule should not exist") - _, found = manager.rules[natRuleKey] + _, found := manager.rules[natRuleKey] require.False(t, found, "nat rule should exist in the manager map") exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) @@ -221,7 +171,175 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { _, found = manager.rules[inNatRuleKey] require.False(t, found, "income nat rule should exist in the manager map") - + }) + } +} + +func TestRouter_AddRouteFiltering(t *testing.T) { + if !isIptablesSupported() { + t.Skip("iptables not supported on this system") + } + + iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + require.NoError(t, err, "Failed to create iptables client") + + r, err := newRouter(context.Background(), iptablesClient, ifaceMock) + require.NoError(t, err, "Failed to create router manager") + + defer func() { + err := r.Reset() + require.NoError(t, err, "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + sPort *firewall.Port + dPort *firewall.Port + direction firewall.RuleDirection + action firewall.Action + expectSet bool + }{ + { + name: "Basic TCP rule with single source", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolTCP, + sPort: nil, + dPort: &firewall.Port{Values: []int{80}}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with multiple sources", + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolUDP, + sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionDrop, + expectSet: true, + }, + { + name: "All protocols rule", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + destination: netip.MustParsePrefix("0.0.0.0/0"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "ICMP rule", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolICMP, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "TCP rule with multiple source ports", + sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, + destination: netip.MustParsePrefix("192.168.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with single IP and port range", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolUDP, + sPort: nil, + dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + { + name: "TCP rule with source and destination ports", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + destination: netip.MustParsePrefix("172.16.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &firewall.Port{Values: []int{22}}, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "Drop all incoming traffic", + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + destination: netip.MustParsePrefix("192.168.0.0/24"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + require.NoError(t, err, "AddRouteFiltering failed") + + // Check if the rule is in the internal map + rule, ok := r.rules[ruleKey.GetRuleID()] + assert.True(t, ok, "Rule not found in internal map") + + // Log the internal rule + t.Logf("Internal rule: %v", rule) + + // Check if the rule exists in iptables + exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...) + assert.NoError(t, err, "Failed to check rule existence") + assert.True(t, exists, "Rule not found in iptables") + + // Verify rule content + params := routeFilteringRuleParams{ + Sources: tt.sources, + Destination: tt.destination, + Proto: tt.proto, + SPort: tt.sPort, + DPort: tt.dPort, + Action: tt.action, + SetName: "", + } + + expectedRule := genRouteFilteringRuleSpec(params) + + if tt.expectSet { + setName := firewall.GenerateSetName(tt.sources) + params.SetName = setName + expectedRule = genRouteFilteringRuleSpec(params) + + // Check if the set was created + _, exists := r.ipsetCounter.Get(setName) + assert.True(t, exists, "IPSet not created") + } + + assert.Equal(t, expectedRule, rule, "Rule content mismatch") + + // Clean up + err = r.DeleteRouteRule(ruleKey) + require.NoError(t, err, "Failed to delete rule") }) } } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 6e4edb63e..556bda0d6 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,15 +1,21 @@ package manager import ( + "crypto/sha256" + "encoding/hex" "fmt" "net" + "net/netip" + "sort" + "strings" + + log "github.com/sirupsen/logrus" ) const ( - NatFormat = "netbird-nat-%s" - ForwardingFormat = "netbird-fwd-%s" - InNatFormat = "netbird-nat-in-%s" - InForwardingFormat = "netbird-fwd-in-%s" + ForwardingFormatPrefix = "netbird-fwd-" + ForwardingFormat = "netbird-fwd-%s-%t" + NatFormat = "netbird-nat-%s-%t" ) // Rule abstraction should be implemented by each firewall manager @@ -49,11 +55,11 @@ type Manager interface { // AllowNetbird allows netbird interface traffic AllowNetbird() error - // AddFiltering rule to the firewall + // AddPeerFiltering adds a rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule - AddFiltering( + AddPeerFiltering( ip net.IP, proto Protocol, sPort *Port, @@ -64,17 +70,25 @@ type Manager interface { comment string, ) ([]Rule, error) - // DeleteRule from the firewall by rule definition - DeleteRule(rule Rule) error + // DeletePeerRule from the firewall by rule definition + DeletePeerRule(rule Rule) error // IsServerRouteSupported returns true if the firewall supports server side routing operations IsServerRouteSupported() bool - // InsertRoutingRules inserts a routing firewall rule - InsertRoutingRules(pair RouterPair) error + AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) - // RemoveRoutingRules removes a routing firewall rule - RemoveRoutingRules(pair RouterPair) error + // DeleteRouteRule deletes a routing rule + DeleteRouteRule(rule Rule) error + + // AddNatRule inserts a routing NAT rule + AddNatRule(pair RouterPair) error + + // RemoveNatRule removes a routing NAT rule + RemoveNatRule(pair RouterPair) error + + // SetLegacyManagement sets the legacy management mode + SetLegacyManagement(legacy bool) error // Reset firewall to the default state Reset() error @@ -83,6 +97,89 @@ type Manager interface { Flush() error } -func GenKey(format string, input string) string { - return fmt.Sprintf(format, input) +func GenKey(format string, pair RouterPair) string { + return fmt.Sprintf(format, pair.ID, pair.Inverse) +} + +// LegacyManager defines the interface for legacy management operations +type LegacyManager interface { + RemoveAllLegacyRouteRules() error + GetLegacyManagement() bool + SetLegacyManagement(bool) +} + +// SetLegacyManagement sets the route manager to use legacy management +func SetLegacyManagement(router LegacyManager, isLegacy bool) error { + oldLegacy := router.GetLegacyManagement() + + if oldLegacy != isLegacy { + router.SetLegacyManagement(isLegacy) + log.Debugf("Set legacy management to %v", isLegacy) + } + + // client reconnected to a newer mgmt, we need to clean up the legacy rules + if !isLegacy && oldLegacy { + if err := router.RemoveAllLegacyRouteRules(); err != nil { + return fmt.Errorf("remove legacy routing rules: %v", err) + } + + log.Debugf("Legacy routing rules removed") + } + + return nil +} + +// GenerateSetName generates a unique name for an ipset based on the given sources. +func GenerateSetName(sources []netip.Prefix) string { + // sort for consistent naming + SortPrefixes(sources) + + var sourcesStr strings.Builder + for _, src := range sources { + sourcesStr.WriteString(src.String()) + } + + hash := sha256.Sum256([]byte(sourcesStr.String())) + shortHash := hex.EncodeToString(hash[:])[:8] + + return fmt.Sprintf("nb-%s", shortHash) +} + +// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix +func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { + if len(prefixes) == 0 { + return prefixes + } + + merged := []netip.Prefix{prefixes[0]} + for _, prefix := range prefixes[1:] { + last := merged[len(merged)-1] + if last.Contains(prefix.Addr()) { + // If the current prefix is contained within the last merged prefix, skip it + continue + } + if prefix.Contains(last.Addr()) { + // If the current prefix contains the last merged prefix, replace it + merged[len(merged)-1] = prefix + } else { + // Otherwise, add the current prefix to the merged list + merged = append(merged, prefix) + } + } + + return merged +} + +// SortPrefixes sorts the given slice of netip.Prefix in place. +// It sorts first by IP address, then by prefix length (most specific to least specific). +func SortPrefixes(prefixes []netip.Prefix) { + sort.Slice(prefixes, func(i, j int) bool { + addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) + if addrCmp != 0 { + return addrCmp < 0 + } + + // If IP addresses are the same, compare prefix lengths (longer prefixes first) + return prefixes[i].Bits() > prefixes[j].Bits() + }) } diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go new file mode 100644 index 000000000..3f47d6679 --- /dev/null +++ b/client/firewall/manager/firewall_test.go @@ -0,0 +1,192 @@ +package manager_test + +import ( + "net/netip" + "reflect" + "regexp" + "testing" + + "github.com/netbirdio/netbird/client/firewall/manager" +) + +func TestGenerateSetName(t *testing.T) { + t.Run("Different orders result in same hash", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := manager.GenerateSetName(prefixes1) + result2 := manager.GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) + } + }) + + t.Run("Result format is correct", func(t *testing.T) { + prefixes := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + + result := manager.GenerateSetName(prefixes) + + matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) + if err != nil { + t.Fatalf("Error matching regex: %v", err) + } + if !matched { + t.Errorf("Result format is incorrect: %s", result) + } + }) + + t.Run("Empty input produces consistent result", func(t *testing.T) { + result1 := manager.GenerateSetName([]netip.Prefix{}) + result2 := manager.GenerateSetName([]netip.Prefix{}) + + if result1 != result2 { + t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) + } + }) + + t.Run("IPv4 and IPv6 mixing", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("2001:db8::/32"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := manager.GenerateSetName(prefixes1) + result2 := manager.GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) + } + }) +} + +func TestMergeIPRanges(t *testing.T) { + tests := []struct { + name string + input []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Empty input", + input: []netip.Prefix{}, + expected: []netip.Prefix{}, + }, + { + name: "Single range", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Two non-overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + }, + { + name: "One range containing another", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "One range containing another (different order)", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Overlapping ranges (different order)", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.128/25"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Multiple overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Partially overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + }, + { + name: "IPv6 ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::/48"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := manager.MergeIPRanges(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go index b63a9f104..8c94b7dd4 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/manager/routerpair.go @@ -1,18 +1,26 @@ package manager +import ( + "net/netip" + + "github.com/netbirdio/netbird/route" +) + type RouterPair struct { - ID string - Source string - Destination string + ID route.ID + Source netip.Prefix + Destination netip.Prefix Masquerade bool + Inverse bool } -func GetInPair(pair RouterPair) RouterPair { +func GetInversePair(pair RouterPair) RouterPair { return RouterPair{ ID: pair.ID, // invert Source/Destination Source: pair.Destination, Destination: pair.Source, Masquerade: pair.Masquerade, + Inverse: true, } } diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 1fa41b63a..61434f035 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -11,12 +11,14 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -29,26 +31,26 @@ const ( chainNameInputFilter = "netbird-acl-input-filter" chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" + chainNamePrerouting = "netbird-rt-prerouting" allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) +const flushError = "flush: %w" + var ( - anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00} + anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} ) type AclManager struct { - rConn *nftables.Conn - sConn *nftables.Conn - wgIface iFaceMapper - routeingFwChainName string + rConn *nftables.Conn + sConn *nftables.Conn + wgIface iFaceMapper + routingFwChainName string workTable *nftables.Table chainInputRules *nftables.Chain chainOutputRules *nftables.Chain - chainFwFilter *nftables.Chain - chainPrerouting *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -61,10 +63,10 @@ type iFaceMapper interface { IsUserspaceBind() bool } -func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { +func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) - // and is permanent. Using same connection for booth type of operations + // and is permanent. Using same connection for both type of operations // overloads netlink with high amount of rules ( > 10000) sConn, err := nftables.New(nftables.AsLasting()) if err != nil { @@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa } m := &AclManager{ - rConn: &nftables.Conn{}, - sConn: sConn, - wgIface: wgIface, - workTable: table, - routeingFwChainName: routeingFwChainName, + rConn: &nftables.Conn{}, + sConn: sConn, + wgIface: wgIface, + workTable: table, + routingFwChainName: routingFwChainName, ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -90,11 +92,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *AclManager) AddFiltering( +func (m *AclManager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -120,20 +122,11 @@ func (m *AclManager) AddFiltering( } newRules = append(newRules, ioRule) - if !shouldAddToPrerouting(proto, dPort, direction) { - return newRules, nil - } - - preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip) - if err != nil { - return newRules, err - } - newRules = append(newRules, preroutingRule) return newRules, nil } -// DeleteRule from the firewall by rule definition -func (m *AclManager) DeleteRule(rule firewall.Rule) error { +// DeletePeerRule from the firewall by rule definition +func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") @@ -199,8 +192,7 @@ func (m *AclManager) DeleteRule(rule firewall.Rule) error { return nil } -// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for -// input and output chains +// createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ &expr.Payload{ @@ -214,13 +206,13 @@ func (m *AclManager) createDefaultAllowRules() error { SourceRegister: 1, DestRegister: 1, Len: 4, - Mask: []byte{0x00, 0x00, 0x00, 0x00}, - Xor: zeroXor, + Mask: []byte{0, 0, 0, 0}, + Xor: []byte{0, 0, 0, 0}, }, // net address &expr.Cmp{ Register: 1, - Data: []byte{0x00, 0x00, 0x00, 0x00}, + Data: []byte{0, 0, 0, 0}, }, &expr.Verdict{ Kind: expr.VerdictAccept, @@ -246,13 +238,13 @@ func (m *AclManager) createDefaultAllowRules() error { SourceRegister: 1, DestRegister: 1, Len: 4, - Mask: []byte{0x00, 0x00, 0x00, 0x00}, - Xor: zeroXor, + Mask: []byte{0, 0, 0, 0}, + Xor: []byte{0, 0, 0, 0}, }, // net address &expr.Cmp{ Register: 1, - Data: []byte{0x00, 0x00, 0x00, 0x00}, + Data: []byte{0, 0, 0, 0}, }, &expr.Verdict{ Kind: expr.VerdictAccept, @@ -266,10 +258,8 @@ func (m *AclManager) createDefaultAllowRules() error { Exprs: expOut, }) - err := m.rConn.Flush() - if err != nil { - log.Debugf("failed to create default allow rules: %s", err) - return err + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) } return nil } @@ -290,15 +280,11 @@ func (m *AclManager) Flush() error { log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) } - if err := m.refreshRuleHandles(m.chainPrerouting); err != nil { - log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err) - } - return nil } func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { - ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset) + ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ r.nftRule, @@ -308,18 +294,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f }, nil } - ifaceKey := expr.MetaKeyIIFNAME - if direction == firewall.RuleDirectionOUT { - ifaceKey = expr.MetaKeyOIFNAME - } - expressions := []expr.Any{ - &expr.Meta{Key: ifaceKey, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - } + var expressions []expr.Any if proto != firewall.ProtocolALL { expressions = append(expressions, &expr.Payload{ @@ -329,21 +304,15 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f Len: uint32(1), }) - var protoData []byte - switch proto { - case firewall.ProtocolTCP: - protoData = []byte{unix.IPPROTO_TCP} - case firewall.ProtocolUDP: - protoData = []byte{unix.IPPROTO_UDP} - case firewall.ProtocolICMP: - protoData = []byte{unix.IPPROTO_ICMP} - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) + protoData, err := protoToInt(proto) + if err != nil { + return nil, fmt.Errorf("convert protocol to number: %v", err) } + expressions = append(expressions, &expr.Cmp{ Register: 1, Op: expr.CmpOpEq, - Data: protoData, + Data: []byte{protoData}, }) } @@ -432,10 +401,9 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f } else { chain = m.chainOutputRules } - nftRule := m.rConn.InsertRule(&nftables.Rule{ + nftRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: chain, - Position: 0, Exprs: expressions, UserData: userData, }) @@ -453,139 +421,13 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f return rule, nil } -func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) { - var protoData []byte - switch proto { - case firewall.ProtocolTCP: - protoData = []byte{unix.IPPROTO_TCP} - case firewall.ProtocolUDP: - protoData = []byte{unix.IPPROTO_UDP} - case firewall.ProtocolICMP: - protoData = []byte{unix.IPPROTO_ICMP} - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) - } - - ruleId := generateRuleIdForMangle(ipset, ip, proto, port) - if r, ok := m.rules[ruleId]; ok { - return &Rule{ - r.nftRule, - r.nftSet, - r.ruleID, - ip, - }, nil - } - - var ipExpression expr.Any - // add individual IP for match if no ipset defined - rawIP := ip.To4() - if ipset == nil { - ipExpression = &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: rawIP, - } - } else { - ipExpression = &expr.Lookup{ - SourceRegister: 1, - SetName: ipset.Name, - SetID: ipset.ID, - } - } - - expressions := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - ipExpression, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: m.wgIface.Address().IP.To4(), - }, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), - Len: uint32(1), - }, - &expr.Cmp{ - Register: 1, - Op: expr.CmpOpEq, - Data: protoData, - }, - } - - if port != nil { - expressions = append(expressions, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseTransportHeader, - Offset: 2, - Len: 2, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: encodePort(*port), - }, - ) - } - - expressions = append(expressions, - &expr.Immediate{ - Register: 1, - Data: postroutingMark, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - SourceRegister: true, - Register: 1, - }, - ) - - nftRule := m.rConn.InsertRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainPrerouting, - Position: 0, - Exprs: expressions, - UserData: []byte(ruleId), - }) - - if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf("flush insert rule: %v", err) - } - - rule := &Rule{ - nftRule: nftRule, - nftSet: ipset, - ruleID: ruleId, - ip: ip, - } - - m.rules[ruleId] = rule - if ipset != nil { - m.ipsetStore.AddReferenceToIpset(ipset.Name) - } - return rule, nil -} - func (m *AclManager) createDefaultChains() (err error) { // chainNameInputRules chain := m.createChain(chainNameInputRules) err = m.rConn.Flush() if err != nil { log.Debugf("failed to create chain (%s): %s", chain.Name, err) - return err + return fmt.Errorf(flushError, err) } m.chainInputRules = chain @@ -601,9 +443,6 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-input-filter // type filter hook input priority filter; policy accept; chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) - //netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept - m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME) - m.addFwdAllow(chain, expr.MetaKeyIIFNAME) m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules m.addDropExpressions(chain, expr.MetaKeyIIFNAME) err = m.rConn.Flush() @@ -615,7 +454,6 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-output-filter // type filter hook output priority filter; policy accept; chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) - m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME) m.addFwdAllow(chain, expr.MetaKeyOIFNAME) m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules m.addDropExpressions(chain, expr.MetaKeyOIFNAME) @@ -626,29 +464,106 @@ func (m *AclManager) createDefaultChains() (err error) { } // netbird-acl-forward-filter - m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) - m.addJumpRulesToRtForward() // to - m.addMarkAccept() - m.addJumpRuleToInputChain() // to netbird-acl-input-rules - m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) + m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd + m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME) + err = m.rConn.Flush() if err != nil { log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err) - return err + return fmt.Errorf(flushError, err) } - // netbird-acl-output-filter - // type filter hook output priority filter; policy accept; - m.chainPrerouting = m.createPreroutingMangle() - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err) - return err + if err := m.allowRedirectedTraffic(chainFwFilter); err != nil { + log.Errorf("failed to allow redirected traffic: %s", err) } + return nil } -func (m *AclManager) addJumpRulesToRtForward() { +// Makes redirected traffic originally destined for the host itself (now subject to the forward filter) +// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the +// netbird peer IP. +func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { + preroutingChain := m.rConn.AddChain(&nftables.Chain{ + Name: chainNamePrerouting, + Table: m.workTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + }) + + m.addPreroutingRule(preroutingChain) + + m.addFwmarkToForward(chainFwFilter) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) { + m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: preroutingChain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Fib{ + Register: 1, + ResultADDRTYPE: true, + FlagDADDR: true, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + SourceRegister: true, + }, + }, + }) +} + +func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { + m.rConn.InsertRule(&nftables.Rule{ + Table: m.workTable, + Chain: chainFwFilter, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: m.chainInputRules.Name, + }, + }, + }) +} + +func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) { expressions := []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -658,68 +573,15 @@ func (m *AclManager) addJumpRulesToRtForward() { }, &expr.Verdict{ Kind: expr.VerdictJump, - Chain: m.routeingFwChainName, + Chain: m.routingFwChainName, }, } _ = m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, - Chain: m.chainFwFilter, + Chain: chainFwFilter, Exprs: expressions, }) - - expressions = []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictJump, - Chain: m.routeingFwChainName, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) -} - -func (m *AclManager) addMarkAccept() { - // oifname "wt0" meta mark 0x000007e4 accept - // iifname "wt0" meta mark 0x000007e4 accept - ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME} - for _, iface := range ifaces { - expressions := []expr.Any{ - &expr.Meta{Key: iface, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: postroutingMark, - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) - } } func (m *AclManager) createChain(name string) *nftables.Chain { @@ -729,10 +591,13 @@ func (m *AclManager) createChain(name string) *nftables.Chain { } chain = m.rConn.AddChain(chain) + + insertReturnTrafficRule(m.rConn, m.workTable, chain) + return chain } -func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain { +func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain { polAccept := nftables.ChainPolicyAccept chain := &nftables.Chain{ Name: name, @@ -746,74 +611,6 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha return m.rConn.AddChain(chain) } -func (m *AclManager) createPreroutingMangle() *nftables.Chain { - polAccept := nftables.ChainPolicyAccept - chain := &nftables.Chain{ - Name: "netbird-acl-prerouting-filter", - Table: m.workTable, - Hooknum: nftables.ChainHookPrerouting, - Priority: nftables.ChainPriorityMangle, - Type: nftables.ChainTypeFilter, - Policy: &polAccept, - } - - chain = m.rConn.AddChain(chain) - - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - expressions := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: m.wgIface.Address().IP.To4(), - }, - &expr.Immediate{ - Register: 1, - Data: postroutingMark, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - SourceRegister: true, - Register: 1, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: chain, - Exprs: expressions, - }) - return chain -} - func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any { expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, @@ -832,101 +629,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met return nil } -func (m *AclManager) addJumpRuleToInputChain() { - expressions := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictJump, - Chain: m.chainInputRules.Name, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) -} - -func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - var srcOp, dstOp expr.CmpOp - if netIfName == expr.MetaKeyIIFNAME { - srcOp = expr.CmpOpEq - dstOp = expr.CmpOpNeq - } else { - srcOp = expr.CmpOpNeq - dstOp = expr.CmpOpEq - } - expressions := []expr.Any{ - &expr.Meta{Key: netIfName, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: srcOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: dstOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: expressions, - }) -} - func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - var srcOp, dstOp expr.CmpOp - if iifname == expr.MetaKeyIIFNAME { - srcOp = expr.CmpOpNeq - dstOp = expr.CmpOpEq - } else { - srcOp = expr.CmpOpEq - dstOp = expr.CmpOpNeq - } + dstOp := expr.CmpOpNeq expressions := []expr.Any{ &expr.Meta{Key: iifname, Register: 1}, &expr.Cmp{ @@ -934,24 +639,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: srcOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, &expr.Payload{ DestRegister: 2, Base: expr.PayloadBaseNetworkHeader, @@ -982,7 +669,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { } func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, &expr.Cmp{ @@ -990,47 +676,12 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, &expr.Verdict{ Kind: expr.VerdictJump, Chain: to, }, } + _ = m.rConn.AddRule(&nftables.Rule{ Table: chain.Table, Chain: chain, @@ -1132,7 +783,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { return nil } -func generateRuleId( +func generatePeerRuleId( ip net.IP, sPort *firewall.Port, dPort *firewall.Port, @@ -1155,33 +806,6 @@ func generateRuleId( } return "set:" + ipset.Name + rulesetID } -func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string { - // case of icmp port is empty - var p string - if port != nil { - p = port.String() - } - if ipset != nil { - return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p) - } else { - return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p) - } -} - -func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { - if proto == "all" { - return false - } - - if direction != firewall.RuleDirectionIN { - return false - } - - if dPort == nil && proto != firewall.ProtocolICMP { - return false - } - return true -} func encodePort(port firewall.Port) []byte { bs := make([]byte, 2) @@ -1191,6 +815,19 @@ func encodePort(port firewall.Port) []byte { func ifname(n string) []byte { b := make([]byte, 16) - copy(b, []byte(n+"\x00")) + copy(b, n+"\x00") return b } + +func protoToInt(protocol firewall.Protocol) (uint8, error) { + switch protocol { + case firewall.ProtocolTCP: + return unix.IPPROTO_TCP, nil + case firewall.ProtocolUDP: + return unix.IPPROTO_UDP, nil + case firewall.ProtocolICMP: + return unix.IPPROTO_ICMP, nil + } + + return 0, fmt.Errorf("unsupported protocol: %s", protocol) +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a376c98c3..01b08bd71 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -5,9 +5,11 @@ import ( "context" "fmt" "net" + "net/netip" "sync" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" @@ -15,8 +17,11 @@ import ( ) const ( - // tableName is the name of the table that is used for filtering by the Netbird client - tableName = "netbird" + // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + tableNameNetbird = "netbird" + + tableNameFilter = "filter" + chainNameInput = "INPUT" ) // Manager of iptables firewall @@ -41,12 +46,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return nil, err } - m.router, err = newRouter(context, workTable) + m.router, err = newRouter(context, workTable, wgIface) if err != nil { return nil, err } - m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName()) + m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) if err != nil { return nil, err } @@ -54,11 +59,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -76,33 +81,52 @@ func (m *Manager) AddFiltering( return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) } - return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) + return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclManager.DeleteRule(rule) + if !destination.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + } + + return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.aclManager.DeletePeerRule(rule) +} + +// DeleteRouteRule deletes a routing rule +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteRouteRule(rule) } func (m *Manager) IsServerRouteSupported() bool { return true } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddRoutingRules(pair) + return m.router.AddNatRule(pair) } -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveRoutingRules(pair) + return m.router.RemoveNatRule(pair) } // AllowNetbird allows netbird interface traffic @@ -126,7 +150,7 @@ func (m *Manager) AllowNetbird() error { var chain *nftables.Chain for _, c := range chains { - if c.Table.Name == "filter" && c.Name == "INPUT" { + if c.Table.Name == tableNameFilter && c.Name == chainNameForward { chain = c break } @@ -157,6 +181,27 @@ func (m *Manager) AllowNetbird() error { return nil } +// SetLegacyManagement sets the route manager to use legacy management +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + oldLegacy := m.router.legacyManagement + + if oldLegacy != isLegacy { + m.router.legacyManagement = isLegacy + log.Debugf("Set legacy management to %v", isLegacy) + } + + // client reconnected to a newer mgmt, we need to cleanup the legacy rules + if !isLegacy && oldLegacy { + if err := m.router.RemoveAllLegacyRouteRules(); err != nil { + return fmt.Errorf("remove legacy routing rules: %v", err) + } + + log.Debugf("Legacy routing rules removed") + } + + return nil +} + // Reset firewall to the default state func (m *Manager) Reset() error { m.mutex.Lock() @@ -185,14 +230,16 @@ func (m *Manager) Reset() error { } } - m.router.ResetForwardRules() + if err := m.router.Reset(); err != nil { + return fmt.Errorf("reset forward rules: %v", err) + } tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list of tables: %w", err) } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } @@ -218,12 +265,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } @@ -239,9 +286,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, + &expr.Verdict{}, }, UserData: []byte(allowNetbirdInputRuleID), } @@ -251,7 +296,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { ifName := ifname(m.wgIface.Name()) for _, rule := range existedRules { - if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" { + if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { if len(rule.Exprs) < 4 { if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { continue @@ -265,3 +310,38 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable } return nil } + +func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: getEstablishedExprs(1), + } + + conn.InsertRule(rule) +} + +func getEstablishedExprs(register uint32) []expr.Any { + return []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: register, + }, + &expr.Bitwise{ + SourceRegister: register, + DestRegister: register, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: register, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } +} diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 1f226e315..bbe18ab07 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -9,14 +9,30 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/stretchr/testify/require" "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) +var ifaceMock = &iFaceMock{ + NameFunc: func() string { + return "lo" + }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("100.96.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("100.96.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string @@ -40,23 +56,9 @@ func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { - mock := &iFaceMock{ - NameFunc: func() string { - return "lo" - }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ - IP: net.ParseIP("100.96.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("100.96.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - } - }, - } // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) time.Sleep(time.Second * 3) @@ -70,7 +72,7 @@ func TestNftablesManager(t *testing.T) { testClient := &nftables.Conn{} - rule, err := manager.AddFiltering( + rule, err := manager.AddPeerFiltering( ip, fw.ProtocolTCP, nil, @@ -88,17 +90,35 @@ func TestNftablesManager(t *testing.T) { rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - require.Len(t, rules, 1, "expected 1 rules") + require.Len(t, rules, 2, "expected 2 rules") + + expectedExprs1 := []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions") ipToAdd, _ := netip.AddrFromSlice(ip) add := ipToAdd.Unmap() - expectedExprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname("lo"), - }, + expectedExprs2 := []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, @@ -134,10 +154,10 @@ func TestNftablesManager(t *testing.T) { }, &expr.Verdict{Kind: expr.VerdictDrop}, } - require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions") + require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions") for _, r := range rule { - err = manager.DeleteRule(r) + err = manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") } @@ -146,7 +166,8 @@ func TestNftablesManager(t *testing.T) { rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - require.Len(t, rules, 0, "expected 0 rules after deletion") + // established rule remains + require.Len(t, rules, 1, "expected 1 rules after deletion") err = manager.Reset() require.NoError(t, err, "failed to reset") @@ -187,9 +208,9 @@ func TestNFtablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/firewall/nftables/route_linux.go b/client/firewall/nftables/route_linux.go deleted file mode 100644 index 71d5ac88e..000000000 --- a/client/firewall/nftables/route_linux.go +++ /dev/null @@ -1,431 +0,0 @@ -package nftables - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "net/netip" - - "github.com/google/nftables" - "github.com/google/nftables/binaryutil" - "github.com/google/nftables/expr" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/firewall/manager" -) - -const ( - chainNameRouteingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-nat" - - userDataAcceptForwardRuleSrc = "frwacceptsrc" - userDataAcceptForwardRuleDst = "frwacceptdst" - - loopbackInterface = "lo\x00" -) - -// some presets for building nftable rules -var ( - zeroXor = binaryutil.NativeEndian.PutUint32(0) - - exprCounterAccept = []expr.Any{ - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") -) - -type router struct { - ctx context.Context - stop context.CancelFunc - conn *nftables.Conn - workTable *nftables.Table - filterTable *nftables.Table - chains map[string]*nftables.Chain - // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules - rules map[string]*nftables.Rule - isDefaultFwdRulesEnabled bool -} - -func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) { - ctx, cancel := context.WithCancel(parentCtx) - - r := &router{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - workTable: workTable, - chains: make(map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), - } - - var err error - r.filterTable, err = r.loadFilterTable() - if err != nil { - if errors.Is(err, errFilterTableNotFound) { - log.Warnf("table 'filter' not found for forward rules") - } else { - return nil, err - } - } - - err = r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - err = r.createContainers() - if err != nil { - log.Errorf("failed to create containers for route: %s", err) - } - return r, err -} - -func (r *router) RouteingFwChainName() string { - return chainNameRouteingFw -} - -// ResetForwardRules cleans existing nftables default forward rules from the system -func (r *router) ResetForwardRules() { - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to reset forward rules: %s", err) - } -} - -func (r *router) loadFilterTable() (*nftables.Table, error) { - tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) - if err != nil { - return nil, fmt.Errorf("nftables: unable to list tables: %v", err) - } - - for _, table := range tables { - if table.Name == "filter" { - return table, nil - } - } - - return nil, errFilterTableNotFound -} - -func (r *router) createContainers() error { - - r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{ - Name: chainNameRouteingFw, - Table: r.workTable, - }) - - r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ - Name: chainNameRoutingNat, - Table: r.workTable, - Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, - Type: nftables.ChainTypeNAT, - }) - - // Add RETURN rule for loopback interface - loRule := &nftables.Rule{ - Table: r.workTable, - Chain: r.chains[chainNameRoutingNat], - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: []byte(loopbackInterface), - }, - &expr.Verdict{Kind: expr.VerdictReturn}, - }, - } - r.conn.InsertRule(loRule) - - err := r.refreshRulesMap() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to initialize table: %v", err) - } - return nil -} - -// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain -func (r *router) AddRoutingRules(pair manager.RouterPair) error { - err := r.refreshRulesMap() - if err != nil { - return err - } - - err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) - if err != nil { - return err - } - err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) - if err != nil { - return err - } - - if pair.Masquerade { - err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) - if err != nil { - return err - } - err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) - if err != nil { - return err - } - } - - if r.filterTable != nil && !r.isDefaultFwdRulesEnabled { - log.Debugf("add default accept forward rule") - r.acceptForwardRule(pair.Source) - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err) - } - return nil -} - -// addRoutingRule inserts a nftable rule to the conn client flush queue -func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) - - var expression []expr.Any - if isNat { - expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic - } else { - expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic - } - - ruleKey := manager.GenKey(format, pair.ID) - - _, exists := r.rules[ruleKey] - if exists { - err := r.removeRoutingRule(format, pair) - if err != nil { - return err - } - } - - r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ - Table: r.workTable, - Chain: r.chains[chainName], - Exprs: expression, - UserData: []byte(ruleKey), - }) - return nil -} - -func (r *router) acceptForwardRule(sourceNetwork string) { - src := generateCIDRMatcherExpressions(true, sourceNetwork) - dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0") - - var exprs []expr.Any - exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic - Kind: expr.VerdictAccept, - })...) - - rule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleSrc), - } - - r.conn.AddRule(rule) - - src = generateCIDRMatcherExpressions(true, "0.0.0.0/0") - dst = generateCIDRMatcherExpressions(false, sourceNetwork) - - exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic - Kind: expr.VerdictAccept, - })...) - - rule = &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleDst), - } - r.conn.AddRule(rule) - r.isDefaultFwdRulesEnabled = true -} - -// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains -func (r *router) RemoveRoutingRules(pair manager.RouterPair) error { - err := r.refreshRulesMap() - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.ForwardingFormat, pair) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair)) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.NatFormat, pair) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair)) - if err != nil { - return err - } - - if len(r.rules) == 0 { - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) - } - log.Debugf("nftables: removed rules for %s", pair.Destination) - return nil -} - -// removeRoutingRule add a nftable rule to the removal queue and delete from rules map -func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error { - ruleKey := manager.GenKey(format, pair.ID) - - rule, found := r.rules[ruleKey] - if found { - ruleType := "forwarding" - if rule.Chain.Type == nftables.ChainTypeNAT { - ruleType = "nat" - } - - err := r.conn.DelRule(rule) - if err != nil { - return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err) - } - - log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination) - - delete(r.rules, ruleKey) - } - return nil -} - -// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid -// duplicates and to get missing attributes that we don't have when adding new rules -func (r *router) refreshRulesMap() error { - for _, chain := range r.chains { - rules, err := r.conn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("nftables: unable to list rules: %v", err) - } - for _, rule := range rules { - if len(rule.UserData) > 0 { - r.rules[string(rule.UserData)] = rule - } - } - } - return nil -} - -func (r *router) cleanUpDefaultForwardRules() error { - if r.filterTable == nil { - r.isDefaultFwdRulesEnabled = false - return nil - } - - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return err - } - - var rules []*nftables.Rule - for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name { - continue - } - if chain.Name != "FORWARD" { - continue - } - - rules, err = r.conn.GetRules(r.filterTable, chain) - if err != nil { - return err - } - } - - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) { - err := r.conn.DelRule(rule) - if err != nil { - return err - } - } - } - r.isDefaultFwdRulesEnabled = false - return r.conn.Flush() -} - -// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR -func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any { - ip, network, _ := net.ParseCIDR(cidr) - ipToAdd, _ := netip.AddrFromSlice(ip) - add := ipToAdd.Unmap() - - var offSet uint32 - if source { - offSet = 12 // src offset - } else { - offSet = 16 // dst offset - } - - return []expr.Any{ - // fetch src add - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: offSet, - Len: 4, - }, - // net mask - &expr.Bitwise{ - DestRegister: 1, - SourceRegister: 1, - Len: 4, - Mask: network.Mask, - Xor: zeroXor, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: add.AsSlice(), - }, - } -} diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go new file mode 100644 index 000000000..404ba6957 --- /dev/null +++ b/client/firewall/nftables/router_linux.go @@ -0,0 +1,869 @@ +package nftables + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "net" + "net/netip" + "strings" + + "github.com/coreos/go-iptables/iptables" + "github.com/davecgh/go-spew/spew" + "github.com/google/nftables" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" +) + +const ( + chainNameRoutingFw = "netbird-rt-fwd" + chainNameRoutingNat = "netbird-rt-postrouting" + chainNameForward = "FORWARD" + + userDataAcceptForwardRuleIif = "frwacceptiif" + userDataAcceptForwardRuleOif = "frwacceptoif" +) + +const refreshRulesMapError = "refresh rules map: %w" + +var ( + errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") +) + +type router struct { + ctx context.Context + stop context.CancelFunc + conn *nftables.Conn + workTable *nftables.Table + filterTable *nftables.Table + chains map[string]*nftables.Chain + // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules + rules map[string]*nftables.Rule + ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] + + wgIface iFaceMapper + legacyManagement bool +} + +func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { + ctx, cancel := context.WithCancel(parentCtx) + + r := &router{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + workTable: workTable, + chains: make(map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + wgIface: wgIface, + } + + r.ipsetCounter = refcounter.New( + r.createIpSet, + r.deleteIpSet, + ) + + var err error + r.filterTable, err = r.loadFilterTable() + if err != nil { + if errors.Is(err, errFilterTableNotFound) { + log.Warnf("table 'filter' not found for forward rules") + } else { + return nil, err + } + } + + err = r.removeAcceptForwardRules() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + err = r.createContainers() + if err != nil { + log.Errorf("failed to create containers for route: %s", err) + } + return r, err +} + +// Reset cleans existing nftables default forward rules from the system +func (r *router) Reset() error { + // clear without deleting the ipsets, the nf table will be deleted by the caller + r.ipsetCounter.Clear() + + return r.removeAcceptForwardRules() +} + +func (r *router) loadFilterTable() (*nftables.Table, error) { + tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + if err != nil { + return nil, fmt.Errorf("nftables: unable to list tables: %v", err) + } + + for _, table := range tables { + if table.Name == "filter" { + return table, nil + } + } + + return nil, errFilterTableNotFound +} + +func (r *router) createContainers() error { + r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingFw, + Table: r.workTable, + }) + + insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + + prio := *nftables.ChainPriorityNATSource - 1 + + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingNat, + Table: r.workTable, + Hooknum: nftables.ChainHookPostrouting, + Priority: &prio, + Type: nftables.ChainTypeNAT, + }) + + if err := r.acceptForwardRules(); err != nil { + log.Errorf("failed to add accept rules for the forward chain: %s", err) + } + + if err := r.refreshRulesMap(); err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("nftables: unable to initialize table: %v", err) + } + + return nil +} + +// AddRouteFiltering appends a nftables rule to the routing chain +func (r *router) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) + if _, ok := r.rules[string(ruleKey)]; ok { + return ruleKey, nil + } + + chain := r.chains[chainNameRoutingFw] + var exprs []expr.Any + + switch { + case len(sources) == 1 && sources[0].Bits() == 0: + // If it's 0.0.0.0/0, we don't need to add any source matching + case len(sources) == 1: + // If there's only one source, we can use it directly + exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) + default: + // If there are multiple sources, create or get an ipset + var err error + exprs, err = r.getIpSetExprs(sources, exprs) + if err != nil { + return nil, fmt.Errorf("get ipset expressions: %w", err) + } + } + + // Handle destination + exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) + + // Handle protocol + if proto != firewall.ProtocolALL { + protoNum, err := protoToInt(proto) + if err != nil { + return nil, fmt.Errorf("convert protocol to number: %w", err) + } + exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}) + exprs = append(exprs, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{protoNum}, + }) + + exprs = append(exprs, applyPort(sPort, true)...) + exprs = append(exprs, applyPort(dPort, false)...) + } + + exprs = append(exprs, &expr.Counter{}) + + var verdict expr.VerdictKind + if action == firewall.ActionAccept { + verdict = expr.VerdictAccept + } else { + verdict = expr.VerdictDrop + } + exprs = append(exprs, &expr.Verdict{Kind: verdict}) + + rule := &nftables.Rule{ + Table: r.workTable, + Chain: chain, + Exprs: exprs, + UserData: []byte(ruleKey), + } + + rule = r.conn.AddRule(rule) + + log.Tracef("Adding route rule %s", spew.Sdump(rule)) + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf(flushError, err) + } + + r.rules[string(ruleKey)] = rule + + log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) + + return ruleKey, nil +} + +func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { + setName := firewall.GenerateSetName(sources) + ref, err := r.ipsetCounter.Increment(setName, sources) + if err != nil { + return nil, fmt.Errorf("create or get ipset for sources: %w", err) + } + + exprs = append(exprs, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: ref.Out.Name, + SetID: ref.Out.ID, + }, + ) + return exprs, nil +} + +func (r *router) DeleteRouteRule(rule firewall.Rule) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleKey := rule.GetRuleID() + nftRule, exists := r.rules[ruleKey] + if !exists { + log.Debugf("route rule %s not found", ruleKey) + return nil + } + + if nftRule.Handle == 0 { + return fmt.Errorf("route rule %s has no handle", ruleKey) + } + + setName := r.findSetNameInRule(nftRule) + + if err := r.deleteNftRule(nftRule, ruleKey); err != nil { + return fmt.Errorf("delete: %w", err) + } + + if setName != "" { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + return fmt.Errorf("decrement ipset reference: %w", err) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { + // overlapping prefixes will result in an error, so we need to merge them + sources = firewall.MergeIPRanges(sources) + + set := &nftables.Set{ + Name: setName, + Table: r.workTable, + // required for prefixes + Interval: true, + KeyType: nftables.TypeIPAddr, + } + + var elements []nftables.SetElement + for _, prefix := range sources { + // TODO: Implement IPv6 support + if prefix.Addr().Is6() { + log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + continue + } + + // nftables needs half-open intervals [firstIP, lastIP) for prefixes + // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc + firstIP := prefix.Addr() + lastIP := calculateLastIP(prefix).Next() + + elements = append(elements, + // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 + // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, + nftables.SetElement{Key: firstIP.AsSlice()}, + nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, + ) + } + + if err := r.conn.AddSet(set, elements); err != nil { + return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) + } + + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf("flush error: %w", err) + } + + log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) + + return set, nil +} + +// calculateLastIP determines the last IP in a given prefix. +func calculateLastIP(prefix netip.Prefix) netip.Addr { + hostMask := ^uint32(0) >> prefix.Masked().Bits() + lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask + + return netip.AddrFrom4(uint32ToBytes(lastIP)) +} + +// Utility function to convert netip.Addr to uint32. +func uint32FromNetipAddr(addr netip.Addr) uint32 { + b := addr.As4() + return binary.BigEndian.Uint32(b[:]) +} + +// Utility function to convert uint32 to a netip-compatible byte slice. +func uint32ToBytes(ip uint32) [4]byte { + var b [4]byte + binary.BigEndian.PutUint32(b[:], ip) + return b +} + +func (r *router) deleteIpSet(setName string, set *nftables.Set) error { + r.conn.DelSet(set) + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + log.Debugf("Deleted unused ipset %s", setName) + return nil +} + +func (r *router) findSetNameInRule(rule *nftables.Rule) string { + for _, e := range rule.Exprs { + if lookup, ok := e.(*expr.Lookup); ok { + return lookup.SetName + } + } + return "" +} + +func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule %s: %w", ruleKey, err) + } + delete(r.rules, ruleKey) + + log.Debugf("removed route rule %s", ruleKey) + + return nil +} + +// AddNatRule appends a nftables rule pair to the nat chain +func (r *router) AddNatRule(pair firewall.RouterPair) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + if r.legacyManagement { + log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) + if err := r.addLegacyRouteRule(pair); err != nil { + return fmt.Errorf("add legacy routing rule: %w", err) + } + } + + if pair.Masquerade { + if err := r.addNatRule(pair); err != nil { + return fmt.Errorf("add nat rule: %w", err) + } + + if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("add inverse nat rule: %w", err) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) + } + + return nil +} + +// addNatRule inserts a nftables rule to the conn client flush queue +func (r *router) addNatRule(pair firewall.RouterPair) error { + sourceExp := generateCIDRMatcherExpressions(true, pair.Source) + destExp := generateCIDRMatcherExpressions(false, pair.Destination) + + dir := expr.MetaKeyIIFNAME + if pair.Inverse { + dir = expr.MetaKeyOIFNAME + } + + intf := ifname(r.wgIface.Name()) + exprs := []expr.Any{ + &expr.Meta{ + Key: dir, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + } + + exprs = append(exprs, sourceExp...) + exprs = append(exprs, destExp...) + exprs = append(exprs, + &expr.Counter{}, &expr.Masq{}, + ) + + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if _, exists := r.rules[ruleKey]; exists { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove routing rule: %w", err) + } + } + + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: exprs, + UserData: []byte(ruleKey), + }) + return nil +} + +// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls +func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { + sourceExp := generateCIDRMatcherExpressions(true, pair.Source) + destExp := generateCIDRMatcherExpressions(false, pair.Destination) + + exprs := []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic + + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if _, exists := r.rules[ruleKey]; exists { + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + } + + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingFw], + Exprs: expression, + UserData: []byte(ruleKey), + }) + return nil +} + +// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls +func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) + + delete(r.rules, ruleKey) + } else { + log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) + } + + return nil +} + +// GetLegacyManagement returns the route manager's legacy management mode +func (r *router) GetLegacyManagement() bool { + return r.legacyManagement +} + +// SetLegacyManagement sets the route manager to use legacy management mode +func (r *router) SetLegacyManagement(isLegacy bool) { + r.legacyManagement = isLegacy +} + +// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls +func (r *router) RemoveAllLegacyRouteRules() error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + var merr *multierror.Error + for k, rule := range r.rules { + if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + continue + } + if err := r.conn.DelRule(rule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) + } + } + return nberrors.FormatErrorOrNil(merr) +} + +// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure +// that our traffic is not dropped by existing rules there. +// The existing FORWARD rules/policies decide outbound traffic towards our interface. +// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. +func (r *router) acceptForwardRules() error { + if r.filterTable == nil { + log.Debugf("table 'filter' not found for forward rules, skipping accept rules") + return nil + } + + fw := "iptables" + + defer func() { + log.Debugf("Used %s to add accept forward rules", fw) + }() + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + // filter table exists but iptables is not + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + + fw = "nftables" + return r.acceptForwardRulesNftables() + } + + return r.acceptForwardRulesIptables(ipt) +} + +func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + } else { + log.Debugf("added iptables rule: %v", rule) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) getAcceptForwardRules() [][]string { + intf := r.wgIface.Name() + return [][]string{ + {"-i", intf, "-j", "ACCEPT"}, + {"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}, + } +} + +func (r *router) acceptForwardRulesNftables() error { + intf := ifname(r.wgIface.Name()) + + // Rule for incoming interface (iif) with counter + iifRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: chainNameForward, + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptForwardRuleIif), + } + r.conn.InsertRule(iifRule) + + oifExprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + } + + // Rule for outgoing interface (oif) with counter + oifRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: append(oifExprs, getEstablishedExprs(2)...), + UserData: []byte(userDataAcceptForwardRuleOif), + } + + r.conn.InsertRule(oifRule) + + return nil +} + +func (r *router) removeAcceptForwardRules() error { + if r.filterTable == nil { + return nil + } + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + return r.removeAcceptForwardRulesNftables() + } + + return r.removeAcceptForwardRulesIptables(ipt) +} + +func (r *router) removeAcceptForwardRulesNftables() error { + chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %v", err) + } + + for _, chain := range chains { + if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + continue + } + + rules, err := r.conn.GetRules(r.filterTable, chain) + if err != nil { + return fmt.Errorf("get rules: %v", err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule: %v", err) + } + } + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +// RemoveNatRule removes a nftables rule pair from nat chains +func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat rule: %w", err) + } + + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } + + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) + } + + log.Debugf("nftables: removed nat rules for %s", pair.Destination) + return nil +} + +// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map +func (r *router) removeNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + err := r.conn.DelRule(rule) + if err != nil { + return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination) + + delete(r.rules, ruleKey) + } else { + log.Debugf("nftables: nat rule %s not found", ruleKey) + } + + return nil +} + +// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid +// duplicates and to get missing attributes that we don't have when adding new rules +func (r *router) refreshRulesMap() error { + for _, chain := range r.chains { + rules, err := r.conn.GetRules(chain.Table, chain) + if err != nil { + return fmt.Errorf("nftables: unable to list rules: %v", err) + } + for _, rule := range rules { + if len(rule.UserData) > 0 { + r.rules[string(rule.UserData)] = rule + } + } + } + return nil +} + +// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR +func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { + var offset uint32 + if source { + offset = 12 // src offset + } else { + offset = 16 // dst offset + } + + ones := prefix.Bits() + // 0.0.0.0/0 doesn't need extra expressions + if ones == 0 { + return nil + } + + mask := net.CIDRMask(ones, 32) + + return []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offset, + Len: 4, + }, + // netmask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 4, + Mask: mask, + Xor: []byte{0, 0, 0, 0}, + }, + // net address + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: prefix.Masked().Addr().AsSlice(), + }, + } +} + +func applyPort(port *firewall.Port, isSource bool) []expr.Any { + if port == nil { + return nil + } + + var exprs []expr.Any + + offset := uint32(2) // Default offset for destination port + if isSource { + offset = 0 // Offset for source port + } + + exprs = append(exprs, &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: offset, + Len: 2, + }) + + if port.IsRange && len(port.Values) == 2 { + // Handle port range + exprs = append(exprs, + &expr.Cmp{ + Op: expr.CmpOpGte, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])), + }, + &expr.Cmp{ + Op: expr.CmpOpLte, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])), + }, + ) + } else { + // Handle single port or multiple ports + for i, p := range port.Values { + if i > 0 { + // Add a bitwise OR operation between port checks + exprs = append(exprs, &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: []byte{0x00, 0x00, 0xff, 0xff}, + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }) + } + exprs = append(exprs, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(p)), + }) + } + } + + return exprs +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 913fbd5d2..25b7587ac 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -4,11 +4,15 @@ package nftables import ( "context" + "encoding/binary" + "net/netip" + "os/exec" "testing" "github.com/coreos/go-iptables/iptables" "github.com/google/nftables" "github.com/google/nftables/expr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -24,56 +28,50 @@ const ( NFTABLES ) -func TestNftablesManager_InsertRoutingRules(t *testing.T) { +func TestNftablesManager_AddNatRule(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this OS") } table, err := createWorkTable() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Failed to create work table") defer deleteWorkTable() for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table) + manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") nftablesTestingClient := &nftables.Conn{} - defer manager.ResetForwardRules() + defer func(manager *router) { + require.NoError(t, manager.Reset(), "failed to reset rules") + }(manager) require.NoError(t, err, "shouldn't return error") - err = manager.AddRoutingRules(testCase.InputPair) - defer func() { - _ = manager.RemoveRoutingRules(testCase.InputPair) - }() - require.NoError(t, err, "forwarding pair should be inserted") + err = manager.AddNatRule(testCase.InputPair) + require.NoError(t, err, "pair should be inserted") - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - testingExpression := append(sourceExp, destExp...) //nolint:gocritic - fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - - found := 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match") - found = 1 - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") + defer func(manager *router, pair firewall.RouterPair) { + require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") + }(manager, testCase.InputPair) if testCase.InputPair.Masquerade { - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + testingExpression := append(sourceExp, destExp...) //nolint:gocritic + testingExpression = append(testingExpression, + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(ifaceMock.Name()), + }, + ) + + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) found := 0 for _, chain := range manager.chains { rules, err := nftablesTestingClient.GetRules(chain.Table, chain) @@ -88,27 +86,20 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { require.Equal(t, 1, found, "should find at least 1 rule to test") } - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) - testingExpression = append(sourceExp, destExp...) //nolint:gocritic - inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - - found = 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match") - found = 1 - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") - if testCase.InputPair.Masquerade { - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + testingExpression := append(sourceExp, destExp...) //nolint:gocritic + testingExpression = append(testingExpression, + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(ifaceMock.Name()), + }, + ) + + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) found := 0 for _, chain := range manager.chains { rules, err := nftablesTestingClient.GetRules(chain.Table, chain) @@ -122,45 +113,37 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { } require.Equal(t, 1, found, "should find at least 1 rule to test") } + }) } } -func TestNftablesManager_RemoveRoutingRules(t *testing.T) { +func TestNftablesManager_RemoveNatRule(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this OS") } table, err := createWorkTable() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Failed to create work table") defer deleteWorkTable() for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table) + manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") nftablesTestingClient := &nftables.Conn{} - defer manager.ResetForwardRules() + defer func(manager *router) { + require.NoError(t, manager.Reset(), "failed to reset rules") + }(manager) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRouteingFw], - Exprs: forwardExp, - UserData: []byte(forwardRuleKey), - }) - natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ Table: manager.workTable, @@ -169,20 +152,11 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { UserData: []byte(natRuleKey), }) - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) - - forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRouteingFw], - Exprs: forwardExp, - UserData: []byte(inForwardRuleKey), - }) + sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) + destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination) natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ Table: manager.workTable, @@ -194,9 +168,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { err = nftablesTestingClient.Flush() require.NoError(t, err, "shouldn't return error") - manager.ResetForwardRules() + err = manager.Reset() + require.NoError(t, err, "shouldn't return error") - err = manager.RemoveRoutingRules(testCase.InputPair) + err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") for _, chain := range manager.chains { @@ -204,9 +179,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) for _, rule := range rules { if len(rule.UserData) > 0 { - require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist") require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") - require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist") require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") } } @@ -215,6 +188,468 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { } } +func TestRouter_AddRouteFiltering(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTable() + require.NoError(t, err, "Failed to create work table") + + defer deleteWorkTable() + + r, err := newRouter(context.Background(), workTable, ifaceMock) + require.NoError(t, err, "Failed to create router") + + defer func(r *router) { + require.NoError(t, r.Reset(), "Failed to reset rules") + }(r) + + tests := []struct { + name string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + sPort *firewall.Port + dPort *firewall.Port + direction firewall.RuleDirection + action firewall.Action + expectSet bool + }{ + { + name: "Basic TCP rule with single source", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolTCP, + sPort: nil, + dPort: &firewall.Port{Values: []int{80}}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with multiple sources", + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolUDP, + sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionDrop, + expectSet: true, + }, + { + name: "All protocols rule", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + destination: netip.MustParsePrefix("0.0.0.0/0"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "ICMP rule", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolICMP, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "TCP rule with multiple source ports", + sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, + destination: netip.MustParsePrefix("192.168.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with single IP and port range", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolUDP, + sPort: nil, + dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + { + name: "TCP rule with source and destination ports", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + destination: netip.MustParsePrefix("172.16.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &firewall.Port{Values: []int{22}}, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "Drop all incoming traffic", + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + destination: netip.MustParsePrefix("192.168.0.0/24"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + require.NoError(t, err, "AddRouteFiltering failed") + + t.Cleanup(func() { + require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule") + }) + + // Check if the rule is in the internal map + rule, ok := r.rules[ruleKey.GetRuleID()] + assert.True(t, ok, "Rule not found in internal map") + + t.Log("Internal rule expressions:") + for i, expr := range rule.Exprs { + t.Logf(" [%d] %T: %+v", i, expr, expr) + } + + // Verify internal rule content + verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) + + // Check if the rule exists in nftables and verify its content + rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw]) + require.NoError(t, err, "Failed to get rules from nftables") + + var nftRule *nftables.Rule + for _, rule := range rules { + if string(rule.UserData) == ruleKey.GetRuleID() { + nftRule = rule + break + } + } + + require.NotNil(t, nftRule, "Rule not found in nftables") + t.Log("Actual nftables rule expressions:") + for i, expr := range nftRule.Exprs { + t.Logf(" [%d] %T: %+v", i, expr, expr) + } + + // Verify actual nftables rule content + verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) + }) + } +} + +func TestNftablesCreateIpSet(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTable() + require.NoError(t, err, "Failed to create work table") + + defer deleteWorkTable() + + r, err := newRouter(context.Background(), workTable, ifaceMock) + require.NoError(t, err, "Failed to create router") + + defer func() { + require.NoError(t, r.Reset(), "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Single IP", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + }, + { + name: "Multiple IPs", + sources: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), + netip.MustParsePrefix("10.0.0.1/32"), + netip.MustParsePrefix("172.16.0.1/32"), + }, + }, + { + name: "Single Subnet", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + { + name: "Multiple Subnets with Various Prefix Lengths", + sources: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("203.0.113.0/26"), + }, + }, + { + name: "Mix of Single IPs and Subnets in Different Positions", + sources: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), + netip.MustParsePrefix("10.0.0.0/16"), + netip.MustParsePrefix("172.16.0.1/32"), + netip.MustParsePrefix("203.0.113.0/24"), + }, + }, + { + name: "Overlapping IPs/Subnets", + sources: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("10.0.0.0/16"), + netip.MustParsePrefix("10.0.0.1/32"), + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.1/32"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + } + + // Add this helper function inside TestNftablesCreateIpSet + printNftSets := func() { + cmd := exec.Command("nft", "list", "sets") + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Failed to run 'nft list sets': %v", err) + } else { + t.Logf("Current nft sets:\n%s", output) + } + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setName := firewall.GenerateSetName(tt.sources) + set, err := r.createIpSet(setName, tt.sources) + if err != nil { + t.Logf("Failed to create IP set: %v", err) + printNftSets() + require.NoError(t, err, "Failed to create IP set") + } + require.NotNil(t, set, "Created set is nil") + + // Verify set properties + assert.Equal(t, setName, set.Name, "Set name mismatch") + assert.Equal(t, r.workTable, set.Table, "Set table mismatch") + assert.True(t, set.Interval, "Set interval property should be true") + assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch") + + // Fetch the created set from nftables + fetchedSet, err := r.conn.GetSetByName(r.workTable, setName) + require.NoError(t, err, "Failed to fetch created set") + require.NotNil(t, fetchedSet, "Fetched set is nil") + + // Verify set elements + elements, err := r.conn.GetSetElements(fetchedSet) + require.NoError(t, err, "Failed to get set elements") + + // Count the number of unique prefixes (excluding interval end markers) + uniquePrefixes := make(map[string]bool) + for _, elem := range elements { + if !elem.IntervalEnd { + ip := netip.AddrFrom4(*(*[4]byte)(elem.Key)) + uniquePrefixes[ip.String()] = true + } + } + + // Check against expected merged prefixes + expectedCount := len(tt.expected) + if expectedCount == 0 { + expectedCount = len(tt.sources) + } + assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected") + + // Verify each expected prefix is in the set + for _, expected := range tt.expected { + found := false + for _, elem := range elements { + if !elem.IntervalEnd { + ip := netip.AddrFrom4(*(*[4]byte)(elem.Key)) + if expected.Contains(ip) { + found = true + break + } + } + } + assert.True(t, found, "Expected prefix %s not found in set", expected) + } + + r.conn.DelSet(set) + if err := r.conn.Flush(); err != nil { + t.Logf("Failed to delete set: %v", err) + printNftSets() + } + require.NoError(t, err, "Failed to delete set") + }) + } +} + +func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { + t.Helper() + + assert.NotNil(t, rule, "Rule should not be nil") + + // Verify sources and destination + if expectSet { + assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources") + } else if len(sources) == 1 && sources[0].Bits() != 0 { + if direction == firewall.RuleDirectionIN { + assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0]) + } else { + assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0]) + } + } + + if direction == firewall.RuleDirectionIN { + assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination) + } else { + assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination) + } + + // Verify protocol + if proto != firewall.ProtocolALL { + assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto) + } + + // Verify ports + if sPort != nil { + assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort) + } + if dPort != nil { + assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort) + } + + // Verify action + assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action) +} + +func containsSetLookup(exprs []expr.Any) bool { + for _, e := range exprs { + if _, ok := e.(*expr.Lookup); ok { + return true + } + } + return false +} + +func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool { + var offset uint32 + if isSource { + offset = 12 // src offset + } else { + offset = 16 // dst offset + } + + var payloadFound, bitwiseFound, cmpFound bool + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Payload: + if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 { + payloadFound = true + } + case *expr.Bitwise: + if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 { + bitwiseFound = true + } + case *expr.Cmp: + if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 { + cmpFound = true + } + } + } + return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0 +} + +func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { + var offset uint32 = 2 // Default offset for destination port + if isSource { + offset = 0 // Offset for source port + } + + var payloadFound, portMatchFound bool + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Payload: + if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 { + payloadFound = true + } + case *expr.Cmp: + if port.IsRange { + if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte { + portMatchFound = true + } + } else { + if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 { + portValue := binary.BigEndian.Uint16(ex.Data) + for _, p := range port.Values { + if uint16(p) == portValue { + portMatchFound = true + break + } + } + } + } + } + if payloadFound && portMatchFound { + return true + } + } + return false +} + +func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { + var metaFound, cmpFound bool + expectedProto, _ := protoToInt(proto) + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Meta: + if ex.Key == expr.MetaKeyL4PROTO { + metaFound = true + } + case *expr.Cmp: + if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto { + cmpFound = true + } + } + } + return metaFound && cmpFound +} + +func containsAction(exprs []expr.Any, action firewall.Action) bool { + for _, e := range exprs { + if verdict, ok := e.(*expr.Verdict); ok { + switch action { + case firewall.ActionAccept: + return verdict.Kind == expr.VerdictAccept + case firewall.ActionDrop: + return verdict.Kind == expr.VerdictDrop + } + } + } + return false +} + // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. func check() int { nf := nftables.Conn{} @@ -250,12 +685,12 @@ func createWorkTable() (*nftables.Table, error) { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { sConn.DelTable(t) } } - table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) err = sConn.Flush() return table, err @@ -273,7 +708,7 @@ func deleteWorkTable() { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { sConn.DelTable(t) } } diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go index 432d113dd..267e93efd 100644 --- a/client/firewall/test/cases_linux.go +++ b/client/firewall/test/cases_linux.go @@ -1,8 +1,10 @@ -//go:build !android - package test -import firewall "github.com/netbirdio/netbird/client/firewall/manager" +import ( + "net/netip" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) var ( InsertRuleTestCases = []struct { @@ -13,8 +15,8 @@ var ( Name: "Insert Forwarding IPV4 Rule", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: false, }, }, @@ -22,8 +24,8 @@ var ( Name: "Insert Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: true, }, }, @@ -38,8 +40,8 @@ var ( Name: "Remove Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: true, }, }, diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 75792e9c0..0e3ee9799 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "net/netip" "sync" "github.com/google/gopacket" @@ -11,7 +12,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) const layerTypeAll = 0 @@ -22,7 +24,7 @@ var ( // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { - SetFilter(iface.PacketFilter) error + SetFilter(device.PacketFilter) error Address() iface.WGAddress } @@ -103,26 +105,26 @@ func (m *Manager) IsServerRouteSupported() bool { } } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } - return m.nativeFirewall.InsertRoutingRules(pair) + return m.nativeFirewall.AddNatRule(pair) } -// RemoveRoutingRules removes a routing firewall rule -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +// RemoveNatRule removes a routing firewall rule +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } - return m.nativeFirewall.RemoveRoutingRules(pair) + return m.nativeFirewall.RemoveNatRule(pair) } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -188,8 +190,22 @@ func (m *Manager) AddFiltering( return []firewall.Rule{&r}, nil } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { + if m.nativeFirewall == nil { + return nil, errRouteNotSupported + } + return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + if m.nativeFirewall == nil { + return errRouteNotSupported + } + return m.nativeFirewall.DeleteRouteRule(rule) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -215,6 +231,11 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error { return nil } +// SetLegacyManagement doesn't need to be implemented for this manager +func (m *Manager) SetLegacyManagement(_ bool) error { + return nil +} + // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } @@ -395,7 +416,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { for _, r := range arr { if r.id == hookID { rule := r - return m.DeleteRule(&rule) + return m.DeletePeerRule(&rule) } } } @@ -403,7 +424,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { for _, r := range arr { if r.id == hookID { rule := r - return m.DeleteRule(&rule) + return m.DeletePeerRule(&rule) } } } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 514a90539..c188deea4 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -11,15 +11,16 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) type IFaceMock struct { - SetFilterFunc func(iface.PacketFilter) error + SetFilterFunc func(device.PacketFilter) error AddressFunc func() iface.WGAddress } -func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { +func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { if i.SetFilterFunc == nil { return fmt.Errorf("not implemented") } @@ -35,7 +36,7 @@ func (i *IFaceMock) Address() iface.WGAddress { func TestManagerCreate(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -49,10 +50,10 @@ func TestManagerCreate(t *testing.T) { } } -func TestManagerAddFiltering(t *testing.T) { +func TestManagerAddPeerFiltering(t *testing.T) { isSetFilterCalled := false ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { + SetFilterFunc: func(device.PacketFilter) error { isSetFilterCalled = true return nil }, @@ -71,7 +72,7 @@ func TestManagerAddFiltering(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -90,7 +91,7 @@ func TestManagerAddFiltering(t *testing.T) { func TestManagerDeleteRule(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -106,7 +107,7 @@ func TestManagerDeleteRule(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -119,14 +120,14 @@ func TestManagerDeleteRule(t *testing.T) { action = fw.ActionDrop comment = "Test rule 2" - rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return } for _, r := range rule { - err = m.DeleteRule(r) + err = m.DeletePeerRule(r) if err != nil { t.Errorf("failed to delete rule: %v", err) return @@ -140,7 +141,7 @@ func TestManagerDeleteRule(t *testing.T) { } for _, r := range rule2 { - err = m.DeleteRule(r) + err = m.DeletePeerRule(r) if err != nil { t.Errorf("failed to delete rule: %v", err) return @@ -236,7 +237,7 @@ func TestAddUDPPacketHook(t *testing.T) { func TestManagerReset(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -252,7 +253,7 @@ func TestManagerReset(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -271,7 +272,7 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -290,7 +291,7 @@ func TestNotMatchByIP(t *testing.T) { action := fw.ActionAccept comment := "Test rule" - _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -339,7 +340,7 @@ func TestNotMatchByIP(t *testing.T) { func TestRemovePacketHook(t *testing.T) { // creating mock iface iface := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } // creating manager instance @@ -388,7 +389,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } manager, err := Create(ifaceMock) require.NoError(t, err) @@ -406,9 +407,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/iface/bind/bind.go b/client/iface/bind/bind.go similarity index 100% rename from iface/bind/bind.go rename to client/iface/bind/bind.go diff --git a/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go similarity index 100% rename from iface/bind/udp_mux.go rename to client/iface/bind/udp_mux.go diff --git a/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go similarity index 100% rename from iface/bind/udp_mux_universal.go rename to client/iface/bind/udp_mux_universal.go diff --git a/iface/bind/udp_muxed_conn.go b/client/iface/bind/udp_muxed_conn.go similarity index 100% rename from iface/bind/udp_muxed_conn.go rename to client/iface/bind/udp_muxed_conn.go diff --git a/client/iface/configurer/err.go b/client/iface/configurer/err.go new file mode 100644 index 000000000..a64bba2dd --- /dev/null +++ b/client/iface/configurer/err.go @@ -0,0 +1,5 @@ +package configurer + +import "errors" + +var ErrPeerNotFound = errors.New("peer not found") diff --git a/iface/wg_configurer_kernel_unix.go b/client/iface/configurer/kernel_unix.go similarity index 83% rename from iface/wg_configurer_kernel_unix.go rename to client/iface/configurer/kernel_unix.go index 8b47082da..7c1c41669 100644 --- a/iface/wg_configurer_kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package configurer import ( "fmt" @@ -12,18 +12,17 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type wgKernelConfigurer struct { +type KernelConfigurer struct { deviceName string } -func newWGConfigurer(deviceName string) wgConfigurer { - wgc := &wgKernelConfigurer{ +func NewKernelConfigurer(deviceName string) *KernelConfigurer { + return &KernelConfigurer{ deviceName: deviceName, } - return wgc } -func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error { +func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -44,7 +43,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err return nil } -func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -75,7 +74,7 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA return nil } -func (c *wgKernelConfigurer) removePeer(peerKey string) error { +func (c *KernelConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -96,7 +95,7 @@ func (c *wgKernelConfigurer) removePeer(peerKey string) error { return nil } -func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error { +func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -123,7 +122,7 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro return nil } -func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { +func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return fmt.Errorf("parse allowed IP: %w", err) @@ -165,7 +164,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e return nil } -func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { +func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { wg, err := wgctrl.New() if err != nil { return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err) @@ -189,7 +188,7 @@ func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer return wgtypes.Peer{}, ErrPeerNotFound } -func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { +func (c *KernelConfigurer) configure(config wgtypes.Config) error { wg, err := wgctrl.New() if err != nil { return err @@ -205,10 +204,10 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { return wg.ConfigureDevice(c.deviceName, config) } -func (c *wgKernelConfigurer) close() { +func (c *KernelConfigurer) Close() { } -func (c *wgKernelConfigurer) getStats(peerKey string) (WGStats, error) { +func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) { peer, err := c.getPeer(c.deviceName, peerKey) if err != nil { return WGStats{}, fmt.Errorf("get wireguard stats: %w", err) diff --git a/iface/name.go b/client/iface/configurer/name.go similarity index 87% rename from iface/name.go rename to client/iface/configurer/name.go index 706cb65ad..e2133d0ea 100644 --- a/iface/name.go +++ b/client/iface/configurer/name.go @@ -1,6 +1,6 @@ //go:build linux || windows || freebsd -package iface +package configurer // WgInterfaceDefault is a default interface name of Wiretrustee const WgInterfaceDefault = "wt0" diff --git a/iface/name_darwin.go b/client/iface/configurer/name_darwin.go similarity index 86% rename from iface/name_darwin.go rename to client/iface/configurer/name_darwin.go index a4016ce15..034ce388d 100644 --- a/iface/name_darwin.go +++ b/client/iface/configurer/name_darwin.go @@ -1,6 +1,6 @@ //go:build darwin -package iface +package configurer // WgInterfaceDefault is a default interface name of Wiretrustee const WgInterfaceDefault = "utun100" diff --git a/iface/uapi.go b/client/iface/configurer/uapi.go similarity index 96% rename from iface/uapi.go rename to client/iface/configurer/uapi.go index d7ff52e7b..4801841de 100644 --- a/iface/uapi.go +++ b/client/iface/configurer/uapi.go @@ -1,6 +1,6 @@ //go:build !windows -package iface +package configurer import ( "net" diff --git a/iface/uapi_windows.go b/client/iface/configurer/uapi_windows.go similarity index 88% rename from iface/uapi_windows.go rename to client/iface/configurer/uapi_windows.go index e1f466364..46fa90c2e 100644 --- a/iface/uapi_windows.go +++ b/client/iface/configurer/uapi_windows.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "net" diff --git a/iface/wg_configurer_usp.go b/client/iface/configurer/usp.go similarity index 93% rename from iface/wg_configurer_usp.go rename to client/iface/configurer/usp.go index cd1d9d0b6..21d65ab2a 100644 --- a/iface/wg_configurer_usp.go +++ b/client/iface/configurer/usp.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "encoding/hex" @@ -19,15 +19,15 @@ import ( var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") -type wgUSPConfigurer struct { +type WGUSPConfigurer struct { device *device.Device deviceName string uapiListener net.Listener } -func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { - wgCfg := &wgUSPConfigurer{ +func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer { + wgCfg := &WGUSPConfigurer{ device: device, deviceName: deviceName, } @@ -35,7 +35,7 @@ func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { return wgCfg } -func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error { +func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -52,7 +52,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -80,7 +80,7 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) removePeer(peerKey string) error { +func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -97,7 +97,7 @@ func (c *wgUSPConfigurer) removePeer(peerKey string) error { return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { +func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -121,7 +121,7 @@ func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { +func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error { ipc, err := c.device.IpcGet() if err != nil { return err @@ -185,7 +185,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { } // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool -func (t *wgUSPConfigurer) startUAPI() { +func (t *WGUSPConfigurer) startUAPI() { var err error t.uapiListener, err = openUAPI(t.deviceName) if err != nil { @@ -207,7 +207,7 @@ func (t *wgUSPConfigurer) startUAPI() { }(t.uapiListener) } -func (t *wgUSPConfigurer) close() { +func (t *WGUSPConfigurer) Close() { if t.uapiListener != nil { err := t.uapiListener.Close() if err != nil { @@ -223,7 +223,7 @@ func (t *wgUSPConfigurer) close() { } } -func (t *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) { +func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) { ipc, err := t.device.IpcGet() if err != nil { return WGStats{}, fmt.Errorf("ipc get: %w", err) diff --git a/iface/wg_configurer_usp_test.go b/client/iface/configurer/usp_test.go similarity index 99% rename from iface/wg_configurer_usp_test.go rename to client/iface/configurer/usp_test.go index ac0fc6130..775339f24 100644 --- a/iface/wg_configurer_usp_test.go +++ b/client/iface/configurer/usp_test.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "encoding/hex" diff --git a/client/iface/configurer/wgstats.go b/client/iface/configurer/wgstats.go new file mode 100644 index 000000000..56d0d7310 --- /dev/null +++ b/client/iface/configurer/wgstats.go @@ -0,0 +1,9 @@ +package configurer + +import "time" + +type WGStats struct { + LastHandshake time.Time + TxBytes int64 + RxBytes int64 +} diff --git a/client/iface/device.go b/client/iface/device.go new file mode 100644 index 000000000..0d4e69145 --- /dev/null +++ b/client/iface/device.go @@ -0,0 +1,18 @@ +//go:build !android + +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" +) + +type WGTunDevice interface { + Create() (device.WGConfigurer, error) + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(address WGAddress) error + WgAddress() WGAddress + DeviceName() string + Close() error + FilteredDevice() *device.FilteredDevice +} diff --git a/iface/tun_adapter.go b/client/iface/device/adapter.go similarity index 94% rename from iface/tun_adapter.go rename to client/iface/device/adapter.go index adec93ed1..6ebc05390 100644 --- a/iface/tun_adapter.go +++ b/client/iface/device/adapter.go @@ -1,4 +1,4 @@ -package iface +package device // TunAdapter is an interface for create tun device from external service type TunAdapter interface { diff --git a/iface/address.go b/client/iface/device/address.go similarity index 69% rename from iface/address.go rename to client/iface/device/address.go index 5ff4fbc06..15de301da 100644 --- a/iface/address.go +++ b/client/iface/device/address.go @@ -1,18 +1,18 @@ -package iface +package device import ( "fmt" "net" ) -// WGAddress Wireguard parsed address +// WGAddress WireGuard parsed address type WGAddress struct { IP net.IP Network *net.IPNet } -// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address -func parseWGAddress(address string) (WGAddress, error) { +// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address +func ParseWGAddress(address string) (WGAddress, error) { ip, network, err := net.ParseCIDR(address) if err != nil { return WGAddress{}, err diff --git a/iface/tun_args.go b/client/iface/device/args.go similarity index 88% rename from iface/tun_args.go rename to client/iface/device/args.go index 0eac2c4c0..d7b86b335 100644 --- a/iface/tun_args.go +++ b/client/iface/device/args.go @@ -1,4 +1,4 @@ -package iface +package device type MobileIFaceArguments struct { TunAdapter TunAdapter // only for Android diff --git a/iface/tun_android.go b/client/iface/device/device_android.go similarity index 61% rename from iface/tun_android.go rename to client/iface/device/device_android.go index 504993094..29e3f409d 100644 --- a/iface/tun_android.go +++ b/client/iface/device/device_android.go @@ -1,7 +1,6 @@ //go:build android -// +build android -package iface +package device import ( "strings" @@ -12,11 +11,12 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -// ignore the wgTunDevice interface on Android because the creation of the tun device is different on this platform -type wgTunDevice struct { +// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform +type WGTunDevice struct { address WGAddress port int key string @@ -24,15 +24,15 @@ type wgTunDevice struct { iceBind *bind.ICEBind tunAdapter TunAdapter - name string - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + name string + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice { - return wgTunDevice{ +func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice { + return &WGTunDevice{ address: address, port: port, key: key, @@ -42,7 +42,7 @@ func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet } } -func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string) (wgConfigurer, error) { +func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { log.Info("create tun interface") routesString := routesToString(routes) @@ -61,24 +61,24 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string return nil, err } t.name = name - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) log.Debugf("attaching to interface %v", name) - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) + t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, err } return t.configurer, nil } -func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -93,14 +93,14 @@ func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *wgTunDevice) UpdateAddr(addr WGAddress) error { +func (t *WGTunDevice) UpdateAddr(addr WGAddress) error { // todo implement return nil } -func (t *wgTunDevice) Close() error { +func (t *WGTunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -115,20 +115,20 @@ func (t *wgTunDevice) Close() error { return nil } -func (t *wgTunDevice) Device() *device.Device { +func (t *WGTunDevice) Device() *device.Device { return t.device } -func (t *wgTunDevice) DeviceName() string { +func (t *WGTunDevice) DeviceName() string { return t.name } -func (t *wgTunDevice) WgAddress() WGAddress { +func (t *WGTunDevice) WgAddress() WGAddress { return t.address } -func (t *wgTunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *WGTunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } func routesToString(routes []string) string { diff --git a/iface/tun_darwin.go b/client/iface/device/device_darwin.go similarity index 69% rename from iface/tun_darwin.go rename to client/iface/device/device_darwin.go index fcf9f8ba0..03e85a7f1 100644 --- a/iface/tun_darwin.go +++ b/client/iface/device/device_darwin.go @@ -1,6 +1,6 @@ //go:build !ios -package iface +package device import ( "fmt" @@ -11,10 +11,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -22,14 +23,14 @@ type tunDevice struct { mtu int iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -39,16 +40,16 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int, } } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { tunDevice, err := tun.CreateTUN(t.name, t.mtu) if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -59,17 +60,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -84,14 +85,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -105,20 +106,20 @@ func (t *tunDevice) Close() error { return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (t *tunDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) if out, err := cmd.CombinedOutput(); err != nil { log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out) diff --git a/iface/device_wrapper.go b/client/iface/device/device_filter.go similarity index 81% rename from iface/device_wrapper.go rename to client/iface/device/device_filter.go index 2fa219395..f87f10429 100644 --- a/iface/device_wrapper.go +++ b/client/iface/device/device_filter.go @@ -1,4 +1,4 @@ -package iface +package device import ( "net" @@ -28,22 +28,23 @@ type PacketFilter interface { SetNetwork(*net.IPNet) } -// DeviceWrapper to override Read or Write of packets -type DeviceWrapper struct { +// FilteredDevice to override Read or Write of packets +type FilteredDevice struct { tun.Device + filter PacketFilter mutex sync.RWMutex } -// newDeviceWrapper constructor function -func newDeviceWrapper(device tun.Device) *DeviceWrapper { - return &DeviceWrapper{ +// newDeviceFilter constructor function +func newDeviceFilter(device tun.Device) *FilteredDevice { + return &FilteredDevice{ Device: device, } } // Read wraps read method with filtering feature -func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { +func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { if n, err = d.Device.Read(bufs, sizes, offset); err != nil { return 0, err } @@ -68,7 +69,7 @@ func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err } // Write wraps write method with filtering feature -func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { +func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { d.mutex.RLock() filter := d.filter d.mutex.RUnlock() @@ -92,7 +93,7 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { } // SetFilter sets packet filter to device -func (d *DeviceWrapper) SetFilter(filter PacketFilter) { +func (d *FilteredDevice) SetFilter(filter PacketFilter) { d.mutex.Lock() d.filter = filter d.mutex.Unlock() diff --git a/iface/device_wrapper_test.go b/client/iface/device/device_filter_test.go similarity index 95% rename from iface/device_wrapper_test.go rename to client/iface/device/device_filter_test.go index 2d3725ea4..d3278b918 100644 --- a/iface/device_wrapper_test.go +++ b/client/iface/device/device_filter_test.go @@ -1,4 +1,4 @@ -package iface +package device import ( "net" @@ -7,7 +7,8 @@ import ( "github.com/golang/mock/gomock" "github.com/google/gopacket" "github.com/google/gopacket/layers" - mocks "github.com/netbirdio/netbird/iface/mocks" + + mocks "github.com/netbirdio/netbird/client/iface/mocks" ) func TestDeviceWrapperRead(t *testing.T) { @@ -51,7 +52,7 @@ func TestDeviceWrapperRead(t *testing.T) { return 1, nil }) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) bufs := [][]byte{{}} sizes := []int{0} @@ -99,7 +100,7 @@ func TestDeviceWrapperRead(t *testing.T) { tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Write(mockBufs, 0).Return(1, nil) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) bufs := [][]byte{buffer.Bytes()} @@ -147,7 +148,7 @@ func TestDeviceWrapperRead(t *testing.T) { filter := mocks.NewMockPacketFilter(ctrl) filter.EXPECT().DropIncoming(gomock.Any()).Return(true) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) wrapped.filter = filter bufs := [][]byte{buffer.Bytes()} @@ -202,7 +203,7 @@ func TestDeviceWrapperRead(t *testing.T) { filter := mocks.NewMockPacketFilter(ctrl) filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) wrapped.filter = filter bufs := [][]byte{{}} diff --git a/iface/tun_ios.go b/client/iface/device/device_ios.go similarity index 63% rename from iface/tun_ios.go rename to client/iface/device/device_ios.go index 6d53cc333..226e8a2e0 100644 --- a/iface/tun_ios.go +++ b/client/iface/device/device_ios.go @@ -1,7 +1,7 @@ //go:build ios // +build ios -package iface +package device import ( "os" @@ -12,10 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -23,14 +24,14 @@ type tunDevice struct { iceBind *bind.ICEBind tunFd int - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -40,7 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, transpor } } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { log.Infof("create tun interface") dupTunFd, err := unix.Dup(t.tunFd) @@ -62,24 +63,24 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, err } - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) log.Debug("Attaching to interface") - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) + t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, err } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -94,17 +95,17 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) Device() *device.Device { +func (t *TunDevice) Device() *device.Device { return t.device } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -119,15 +120,15 @@ func (t *tunDevice) Close() error { return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) UpdateAddr(addr WGAddress) error { +func (t *TunDevice) UpdateAddr(addr WGAddress) error { // todo implement return nil } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } diff --git a/iface/tun_kernel_unix.go b/client/iface/device/device_kernel_unix.go similarity index 75% rename from iface/tun_kernel_unix.go rename to client/iface/device/device_kernel_unix.go index 220c07888..f355d2cf7 100644 --- a/iface/tun_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package device import ( "context" @@ -10,11 +10,12 @@ import ( "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/sharedsock" ) -type tunKernelDevice struct { +type TunKernelDevice struct { name string address WGAddress wgPort int @@ -31,11 +32,11 @@ type tunKernelDevice struct { filterFn bind.FilterFn } -func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { +func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { checkUser() ctx, cancel := context.WithCancel(context.Background()) - return &tunKernelDevice{ + return &TunKernelDevice{ ctx: ctx, ctxCancel: cancel, name: name, @@ -47,7 +48,7 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in } } -func (t *tunKernelDevice) Create() (wgConfigurer, error) { +func (t *TunKernelDevice) Create() (WGConfigurer, error) { link := newWGLink(t.name) if err := link.recreate(); err != nil { @@ -67,16 +68,16 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("set mtu: %w", err) } - configurer := newWGConfigurer(t.name) + configurer := configurer.NewKernelConfigurer(t.name) - if err := configurer.configureInterface(t.key, t.wgPort); err != nil { + if err := configurer.ConfigureInterface(t.key, t.wgPort); err != nil { return nil, fmt.Errorf("error configuring interface: %s", err) } return configurer, nil } -func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -111,12 +112,12 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return t.udpMux, nil } -func (t *tunKernelDevice) UpdateAddr(address WGAddress) error { +func (t *TunKernelDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunKernelDevice) Close() error { +func (t *TunKernelDevice) Close() error { if t.link == nil { return nil } @@ -144,19 +145,19 @@ func (t *tunKernelDevice) Close() error { return closErr } -func (t *tunKernelDevice) WgAddress() WGAddress { +func (t *TunKernelDevice) WgAddress() WGAddress { return t.address } -func (t *tunKernelDevice) DeviceName() string { +func (t *TunKernelDevice) DeviceName() string { return t.name } -func (t *tunKernelDevice) Wrapper() *DeviceWrapper { +func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { return nil } // assignAddr Adds IP address to the tunnel interface -func (t *tunKernelDevice) assignAddr() error { +func (t *TunKernelDevice) assignAddr() error { return t.link.assignAddr(t.address) } diff --git a/iface/tun_netstack.go b/client/iface/device/device_netstack.go similarity index 56% rename from iface/tun_netstack.go rename to client/iface/device/device_netstack.go index de1ff6654..440a1ca19 100644 --- a/iface/tun_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,7 +1,7 @@ //go:build !android // +build !android -package iface +package device import ( "fmt" @@ -10,11 +10,12 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/netstack" ) -type tunNetstackDevice struct { +type TunNetstackDevice struct { name string address WGAddress port int @@ -23,15 +24,15 @@ type tunNetstackDevice struct { listenAddress string iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - nsTun *netstack.NetStackTun - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + nsTun *netstack.NetStackTun + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice { - return &tunNetstackDevice{ +func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice { + return &TunNetstackDevice{ name: name, address: address, port: wgPort, @@ -42,23 +43,23 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string } } -func (t *tunNetstackDevice) Create() (wgConfigurer, error) { +func (t *TunNetstackDevice) Create() (WGConfigurer, error) { log.Info("create netstack tun interface") t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu) tunIface, err := t.nsTun.Create() if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunIface) + t.filteredDevice = newDeviceFilter(tunIface) t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { _ = tunIface.Close() return nil, fmt.Errorf("error configuring interface: %s", err) @@ -68,7 +69,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) { return t.configurer, nil } -func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -87,13 +88,13 @@ func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunNetstackDevice) UpdateAddr(WGAddress) error { +func (t *TunNetstackDevice) UpdateAddr(WGAddress) error { return nil } -func (t *tunNetstackDevice) Close() error { +func (t *TunNetstackDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -106,14 +107,14 @@ func (t *tunNetstackDevice) Close() error { return nil } -func (t *tunNetstackDevice) WgAddress() WGAddress { +func (t *TunNetstackDevice) WgAddress() WGAddress { return t.address } -func (t *tunNetstackDevice) DeviceName() string { +func (t *TunNetstackDevice) DeviceName() string { return t.name } -func (t *tunNetstackDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } diff --git a/iface/tun_usp_unix.go b/client/iface/device/device_usp_unix.go similarity index 63% rename from iface/tun_usp_unix.go rename to client/iface/device/device_usp_unix.go index 1c1d3ac89..4175f6556 100644 --- a/iface/tun_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package device import ( "fmt" @@ -12,10 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunUSPDevice struct { +type USPDevice struct { name string address WGAddress port int @@ -23,39 +24,38 @@ type tunUSPDevice struct { mtu int iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { +func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice { log.Infof("using userspace bind mode") checkUser() - return &tunUSPDevice{ + return &USPDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), - } + iceBind: bind.NewICEBind(transportNet, filterFn)} } -func (t *tunUSPDevice) Create() (wgConfigurer, error) { +func (t *USPDevice) Create() (WGConfigurer, error) { log.Info("create tun interface") tunIface, err := tun.CreateTUN(t.name, t.mtu) if err != nil { log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err) return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunIface) + t.filteredDevice = newDeviceFilter(tunIface) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -66,17 +66,17 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -96,14 +96,14 @@ func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunUSPDevice) UpdateAddr(address WGAddress) error { +func (t *USPDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunUSPDevice) Close() error { +func (t *USPDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -116,20 +116,20 @@ func (t *tunUSPDevice) Close() error { return nil } -func (t *tunUSPDevice) WgAddress() WGAddress { +func (t *USPDevice) WgAddress() WGAddress { return t.address } -func (t *tunUSPDevice) DeviceName() string { +func (t *USPDevice) DeviceName() string { return t.name } -func (t *tunUSPDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *USPDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } // assignAddr Adds IP address to the tunnel interface -func (t *tunUSPDevice) assignAddr() error { +func (t *USPDevice) assignAddr() error { link := newWGLink(t.name) return link.assignAddr(t.address) diff --git a/iface/tun_windows.go b/client/iface/device/device_windows.go similarity index 75% rename from iface/tun_windows.go rename to client/iface/device/device_windows.go index afb67bcc0..f3e216ccd 100644 --- a/iface/tun_windows.go +++ b/client/iface/device/device_windows.go @@ -1,4 +1,4 @@ -package iface +package device import ( "fmt" @@ -11,12 +11,13 @@ import ( "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -26,13 +27,13 @@ type tunDevice struct { device *device.Device nativeTunDevice *tun.NativeTun - wrapper *DeviceWrapper + filteredDevice *FilteredDevice udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -50,7 +51,7 @@ func getGUID() (windows.GUID, error) { return windows.GUIDFromString(guidString) } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { guid, err := getGUID() if err != nil { log.Errorf("failed to get GUID: %s", err) @@ -62,11 +63,11 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error creating tun device: %s", err) } t.nativeTunDevice = tunDevice.(*tun.NativeTun) - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -92,17 +93,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -117,14 +118,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -138,19 +139,19 @@ func (t *tunDevice) Close() error { } return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } -func (t *tunDevice) getInterfaceGUIDString() (string, error) { +func (t *TunDevice) GetInterfaceGUIDString() (string, error) { if t.nativeTunDevice == nil { return "", fmt.Errorf("interface has not been initialized yet") } @@ -164,7 +165,7 @@ func (t *tunDevice) getInterfaceGUIDString() (string, error) { } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (t *tunDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { luid := winipcfg.LUID(t.nativeTunDevice.LUID()) log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go new file mode 100644 index 000000000..0196b0085 --- /dev/null +++ b/client/iface/device/interface.go @@ -0,0 +1,20 @@ +package device + +import ( + "net" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +type WGConfigurer interface { + ConfigureInterface(privateKey string, port int) error + UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeer(peerKey string) error + AddAllowedIP(peerKey string, allowedIP string) error + RemoveAllowedIP(peerKey string, allowedIP string) error + Close() + GetStats(peerKey string) (configurer.WGStats, error) +} diff --git a/iface/module.go b/client/iface/device/kernel_module.go similarity index 92% rename from iface/module.go rename to client/iface/device/kernel_module.go index ca70cf3c7..1bdd6f7c6 100644 --- a/iface/module.go +++ b/client/iface/device/kernel_module.go @@ -1,6 +1,6 @@ //go:build (!linux && !freebsd) || android -package iface +package device // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only) func WireGuardModuleIsLoaded() bool { diff --git a/iface/module_freebsd.go b/client/iface/device/kernel_module_freebsd.go similarity index 84% rename from iface/module_freebsd.go rename to client/iface/device/kernel_module_freebsd.go index 00ad882c2..dd6c8b408 100644 --- a/iface/module_freebsd.go +++ b/client/iface/device/kernel_module_freebsd.go @@ -1,4 +1,4 @@ -package iface +package device // WireGuardModuleIsLoaded check if kernel support wireguard func WireGuardModuleIsLoaded() bool { @@ -10,8 +10,8 @@ func WireGuardModuleIsLoaded() bool { return false } -// tunModuleIsLoaded check if tun module exist, if is not attempt to load it -func tunModuleIsLoaded() bool { +// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it +func ModuleTunIsLoaded() bool { // Assume tun supported by freebsd kernel by default // TODO: implement check for module loaded in kernel or build-it return true diff --git a/iface/module_linux.go b/client/iface/device/kernel_module_linux.go similarity index 98% rename from iface/module_linux.go rename to client/iface/device/kernel_module_linux.go index 11c0482d5..0d195779d 100644 --- a/iface/module_linux.go +++ b/client/iface/device/kernel_module_linux.go @@ -1,7 +1,7 @@ //go:build linux && !android // Package iface provides wireguard network interface creation and management -package iface +package device import ( "bufio" @@ -66,8 +66,8 @@ func getModuleRoot() string { return filepath.Join(moduleLibDir, string(uname.Release[:i])) } -// tunModuleIsLoaded check if tun module exist, if is not attempt to load it -func tunModuleIsLoaded() bool { +// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it +func ModuleTunIsLoaded() bool { _, err := os.Stat("/dev/net/tun") if err == nil { return true diff --git a/iface/module_linux_test.go b/client/iface/device/kernel_module_linux_test.go similarity index 98% rename from iface/module_linux_test.go rename to client/iface/device/kernel_module_linux_test.go index 97e9b1f78..de9656e47 100644 --- a/iface/module_linux_test.go +++ b/client/iface/device/kernel_module_linux_test.go @@ -1,4 +1,6 @@ -package iface +//go:build linux && !android + +package device import ( "bufio" @@ -132,7 +134,7 @@ func resetGlobals() { } func createFiles(t *testing.T) (string, []module) { - t.Helper() + t.Helper() writeFile := func(path, text string) { if err := os.WriteFile(path, []byte(text), 0644); err != nil { t.Fatal(err) @@ -168,7 +170,7 @@ func createFiles(t *testing.T) (string, []module) { } func getRandomLoadedModule(t *testing.T) (string, error) { - t.Helper() + t.Helper() f, err := os.Open("/proc/modules") if err != nil { return "", err diff --git a/iface/tun_link_freebsd.go b/client/iface/device/wg_link_freebsd.go similarity index 95% rename from iface/tun_link_freebsd.go rename to client/iface/device/wg_link_freebsd.go index be7921fdb..104010f47 100644 --- a/iface/tun_link_freebsd.go +++ b/client/iface/device/wg_link_freebsd.go @@ -1,10 +1,11 @@ -package iface +package device import ( "fmt" - "github.com/netbirdio/netbird/iface/freebsd" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/freebsd" ) type wgLink struct { diff --git a/iface/tun_link_linux.go b/client/iface/device/wg_link_linux.go similarity index 99% rename from iface/tun_link_linux.go rename to client/iface/device/wg_link_linux.go index 3ce644e84..a15cffe48 100644 --- a/iface/tun_link_linux.go +++ b/client/iface/device/wg_link_linux.go @@ -1,6 +1,6 @@ //go:build linux && !android -package iface +package device import ( "fmt" diff --git a/iface/wg_log.go b/client/iface/device/wg_log.go similarity index 93% rename from iface/wg_log.go rename to client/iface/device/wg_log.go index b44f6fc0b..db2f3111f 100644 --- a/iface/wg_log.go +++ b/client/iface/device/wg_log.go @@ -1,4 +1,4 @@ -package iface +package device import ( "os" diff --git a/client/iface/device/windows_guid.go b/client/iface/device/windows_guid.go new file mode 100644 index 000000000..1c7d40d13 --- /dev/null +++ b/client/iface/device/windows_guid.go @@ -0,0 +1,4 @@ +package device + +// CustomWindowsGUIDString is a custom GUID string for the interface +var CustomWindowsGUIDString string diff --git a/client/iface/device_android.go b/client/iface/device_android.go new file mode 100644 index 000000000..3d15080ff --- /dev/null +++ b/client/iface/device_android.go @@ -0,0 +1,16 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" +) + +type WGTunDevice interface { + Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(address WGAddress) error + WgAddress() WGAddress + DeviceName() string + Close() error + FilteredDevice() *device.FilteredDevice +} diff --git a/iface/freebsd/errors.go b/client/iface/freebsd/errors.go similarity index 100% rename from iface/freebsd/errors.go rename to client/iface/freebsd/errors.go diff --git a/iface/freebsd/iface.go b/client/iface/freebsd/iface.go similarity index 100% rename from iface/freebsd/iface.go rename to client/iface/freebsd/iface.go diff --git a/iface/freebsd/iface_internal_test.go b/client/iface/freebsd/iface_internal_test.go similarity index 100% rename from iface/freebsd/iface_internal_test.go rename to client/iface/freebsd/iface_internal_test.go diff --git a/iface/freebsd/link.go b/client/iface/freebsd/link.go similarity index 100% rename from iface/freebsd/link.go rename to client/iface/freebsd/link.go diff --git a/iface/iface.go b/client/iface/iface.go similarity index 79% rename from iface/iface.go rename to client/iface/iface.go index 545feffcf..accf5ce0a 100644 --- a/iface/iface.go +++ b/client/iface/iface.go @@ -9,28 +9,27 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) const ( - DefaultMTU = 1280 - DefaultWgPort = 51820 + DefaultMTU = 1280 + DefaultWgPort = 51820 + WgInterfaceDefault = configurer.WgInterfaceDefault ) -// WGIface represents a interface instance +type WGAddress = device.WGAddress + +// WGIface represents an interface instance type WGIface struct { - tun wgTunDevice + tun WGTunDevice userspaceBind bool mu sync.Mutex - configurer wgConfigurer - filter PacketFilter -} - -type WGStats struct { - LastHandshake time.Time - TxBytes int64 - RxBytes int64 + configurer device.WGConfigurer + filter device.PacketFilter } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -44,7 +43,7 @@ func (w *WGIface) Name() string { } // Address returns the interface address -func (w *WGIface) Address() WGAddress { +func (w *WGIface) Address() device.WGAddress { return w.tun.WgAddress() } @@ -75,7 +74,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error { w.mu.Lock() defer w.mu.Unlock() - addr, err := parseWGAddress(newAddr) + addr, err := device.ParseWGAddress(newAddr) if err != nil { return err } @@ -90,7 +89,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D defer w.mu.Unlock() log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) - return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) + return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } // RemovePeer removes a Wireguard Peer from the interface iface @@ -99,7 +98,7 @@ func (w *WGIface) RemovePeer(peerKey string) error { defer w.mu.Unlock() log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName()) - return w.configurer.removePeer(peerKey) + return w.configurer.RemovePeer(peerKey) } // AddAllowedIP adds a prefix to the allowed IPs list of peer @@ -108,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { defer w.mu.Unlock() log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) - return w.configurer.addAllowedIP(peerKey, allowedIP) + return w.configurer.AddAllowedIP(peerKey, allowedIP) } // RemoveAllowedIP removes a prefix from the allowed IPs list of peer @@ -117,7 +116,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { defer w.mu.Unlock() log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) - return w.configurer.removeAllowedIP(peerKey, allowedIP) + return w.configurer.RemoveAllowedIP(peerKey, allowedIP) } // Close closes the tunnel interface @@ -144,23 +143,23 @@ func (w *WGIface) Close() error { } // SetFilter sets packet filters for the userspace implementation -func (w *WGIface) SetFilter(filter PacketFilter) error { +func (w *WGIface) SetFilter(filter device.PacketFilter) error { w.mu.Lock() defer w.mu.Unlock() - if w.tun.Wrapper() == nil { + if w.tun.FilteredDevice() == nil { return fmt.Errorf("userspace packet filtering not handled on this device") } w.filter = filter w.filter.SetNetwork(w.tun.WgAddress().Network) - w.tun.Wrapper().SetFilter(filter) + w.tun.FilteredDevice().SetFilter(filter) return nil } // GetFilter returns packet filter used by interface if it uses userspace device implementation -func (w *WGIface) GetFilter() PacketFilter { +func (w *WGIface) GetFilter() device.PacketFilter { w.mu.Lock() defer w.mu.Unlock() @@ -168,16 +167,16 @@ func (w *WGIface) GetFilter() PacketFilter { } // GetDevice to interact with raw device (with filtering) -func (w *WGIface) GetDevice() *DeviceWrapper { +func (w *WGIface) GetDevice() *device.FilteredDevice { w.mu.Lock() defer w.mu.Unlock() - return w.tun.Wrapper() + return w.tun.FilteredDevice() } // GetStats returns the last handshake time, rx and tx bytes for the given peer -func (w *WGIface) GetStats(peerKey string) (WGStats, error) { - return w.configurer.getStats(peerKey) +func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { + return w.configurer.GetStats(peerKey) } func (w *WGIface) waitUntilRemoved() error { diff --git a/iface/iface_android.go b/client/iface/iface_android.go similarity index 67% rename from iface/iface_android.go rename to client/iface/iface_android.go index 99f6885a5..5ed476e70 100644 --- a/iface/iface_android.go +++ b/client/iface/iface_android.go @@ -5,18 +5,19 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), + tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_create.go b/client/iface/iface_create.go similarity index 100% rename from iface/iface_create.go rename to client/iface/iface_create.go diff --git a/iface/iface_darwin.go b/client/iface/iface_darwin.go similarity index 68% rename from iface/iface_darwin.go rename to client/iface/iface_darwin.go index f48f324c3..b46ea0f80 100644 --- a/iface/iface_darwin.go +++ b/client/iface/iface_darwin.go @@ -9,13 +9,14 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -25,11 +26,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) + wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } diff --git a/iface/iface_destroy_bsd.go b/client/iface/iface_destroy_bsd.go similarity index 100% rename from iface/iface_destroy_bsd.go rename to client/iface/iface_destroy_bsd.go diff --git a/iface/iface_destroy_linux.go b/client/iface/iface_destroy_linux.go similarity index 100% rename from iface/iface_destroy_linux.go rename to client/iface/iface_destroy_linux.go diff --git a/iface/iface_destroy_mobile.go b/client/iface/iface_destroy_mobile.go similarity index 100% rename from iface/iface_destroy_mobile.go rename to client/iface/iface_destroy_mobile.go diff --git a/iface/iface_destroy_windows.go b/client/iface/iface_destroy_windows.go similarity index 100% rename from iface/iface_destroy_windows.go rename to client/iface/iface_destroy_windows.go diff --git a/iface/iface_ios.go b/client/iface/iface_ios.go similarity index 59% rename from iface/iface_ios.go rename to client/iface/iface_ios.go index 6babe5964..fc0214748 100644 --- a/iface/iface_ios.go +++ b/client/iface/iface_ios.go @@ -7,17 +7,18 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), + tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_moc.go b/client/iface/iface_moc.go similarity index 76% rename from iface/iface_moc.go rename to client/iface/iface_moc.go index fab3054a0..703da9ce0 100644 --- a/iface/iface_moc.go +++ b/client/iface/iface_moc.go @@ -6,7 +6,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) type MockWGIface struct { @@ -14,7 +16,7 @@ type MockWGIface struct { CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error IsUserspaceBindFunc func() bool NameFunc func() string - AddressFunc func() WGAddress + AddressFunc func() device.WGAddress ToInterfaceFunc func() *net.Interface UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error @@ -23,10 +25,10 @@ type MockWGIface struct { AddAllowedIPFunc func(peerKey string, allowedIP string) error RemoveAllowedIPFunc func(peerKey string, allowedIP string) error CloseFunc func() error - SetFilterFunc func(filter PacketFilter) error - GetFilterFunc func() PacketFilter - GetDeviceFunc func() *DeviceWrapper - GetStatsFunc func(peerKey string) (WGStats, error) + SetFilterFunc func(filter device.PacketFilter) error + GetFilterFunc func() device.PacketFilter + GetDeviceFunc func() *device.FilteredDevice + GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) } @@ -50,7 +52,7 @@ func (m *MockWGIface) Name() string { return m.NameFunc() } -func (m *MockWGIface) Address() WGAddress { +func (m *MockWGIface) Address() device.WGAddress { return m.AddressFunc() } @@ -86,18 +88,18 @@ func (m *MockWGIface) Close() error { return m.CloseFunc() } -func (m *MockWGIface) SetFilter(filter PacketFilter) error { +func (m *MockWGIface) SetFilter(filter device.PacketFilter) error { return m.SetFilterFunc(filter) } -func (m *MockWGIface) GetFilter() PacketFilter { +func (m *MockWGIface) GetFilter() device.PacketFilter { return m.GetFilterFunc() } -func (m *MockWGIface) GetDevice() *DeviceWrapper { +func (m *MockWGIface) GetDevice() *device.FilteredDevice { return m.GetDeviceFunc() } -func (m *MockWGIface) GetStats(peerKey string) (WGStats, error) { +func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { return m.GetStatsFunc(peerKey) } diff --git a/iface/iface_test.go b/client/iface/iface_test.go similarity index 98% rename from iface/iface_test.go rename to client/iface/iface_test.go index 8de9f647e..87a68addb 100644 --- a/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" ) // keep darwin compatibility @@ -414,7 +416,7 @@ func Test_ConnectPeers(t *testing.T) { } guid := fmt.Sprintf("{%s}", uuid.New().String()) - CustomWindowsGUIDString = strings.ToLower(guid) + device.CustomWindowsGUIDString = strings.ToLower(guid) iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil) if err != nil { @@ -436,7 +438,7 @@ func Test_ConnectPeers(t *testing.T) { } guid = fmt.Sprintf("{%s}", uuid.New().String()) - CustomWindowsGUIDString = strings.ToLower(guid) + device.CustomWindowsGUIDString = strings.ToLower(guid) newNet, err = stdnet.NewNet() if err != nil { diff --git a/iface/iface_unix.go b/client/iface/iface_unix.go similarity index 53% rename from iface/iface_unix.go rename to client/iface/iface_unix.go index 9608df1ad..09dbb2c1f 100644 --- a/iface/iface_unix.go +++ b/client/iface/iface_unix.go @@ -8,13 +8,14 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -23,21 +24,21 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, // move the kernel/usp/netstack preference evaluation to upper layer if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) wgIFace.userspaceBind = true return wgIFace, nil } - if WireGuardModuleIsLoaded() { - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + if device.WireGuardModuleIsLoaded() { + wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) wgIFace.userspaceBind = false return wgIFace, nil } - if !tunModuleIsLoaded() { + if !device.ModuleTunIsLoaded() { return nil, fmt.Errorf("couldn't check or load tun module") } - wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) + wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) wgIFace.userspaceBind = true return wgIFace, nil } diff --git a/iface/iface_windows.go b/client/iface/iface_windows.go similarity index 52% rename from iface/iface_windows.go rename to client/iface/iface_windows.go index c5edd27a9..6845ef3dd 100644 --- a/iface/iface_windows.go +++ b/client/iface/iface_windows.go @@ -5,13 +5,14 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -21,11 +22,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) + wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } @@ -36,5 +37,5 @@ func (w *WGIface) CreateOnAndroid([]string, string, []string) error { // GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only func (w *WGIface) GetInterfaceGUIDString() (string, error) { - return w.tun.(*tunDevice).getInterfaceGUIDString() + return w.tun.(*device.TunDevice).GetInterfaceGUIDString() } diff --git a/iface/iwginterface.go b/client/iface/iwginterface.go similarity index 65% rename from iface/iwginterface.go rename to client/iface/iwginterface.go index 501f51d2b..cb6d7ccd9 100644 --- a/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -8,7 +8,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) type IWGIface interface { @@ -16,7 +18,7 @@ type IWGIface interface { CreateOnAndroid(routeRange []string, ip string, domains []string) error IsUserspaceBind() bool Name() string - Address() WGAddress + Address() device.WGAddress ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error @@ -25,8 +27,8 @@ type IWGIface interface { AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error Close() error - SetFilter(filter PacketFilter) error - GetFilter() PacketFilter - GetDevice() *DeviceWrapper - GetStats(peerKey string) (WGStats, error) + SetFilter(filter device.PacketFilter) error + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go similarity index 65% rename from iface/iwginterface_windows.go rename to client/iface/iwginterface_windows.go index b5053474e..6baeb66ae 100644 --- a/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -6,7 +6,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) type IWGIface interface { @@ -14,7 +16,7 @@ type IWGIface interface { CreateOnAndroid(routeRange []string, ip string, domains []string) error IsUserspaceBind() bool Name() string - Address() WGAddress + Address() device.WGAddress ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error @@ -23,9 +25,9 @@ type IWGIface interface { AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error Close() error - SetFilter(filter PacketFilter) error - GetFilter() PacketFilter - GetDevice() *DeviceWrapper - GetStats(peerKey string) (WGStats, error) + SetFilter(filter device.PacketFilter) error + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/iface/mocks/README.md b/client/iface/mocks/README.md similarity index 100% rename from iface/mocks/README.md rename to client/iface/mocks/README.md diff --git a/iface/mocks/filter.go b/client/iface/mocks/filter.go similarity index 97% rename from iface/mocks/filter.go rename to client/iface/mocks/filter.go index 2d80d69f1..6348e0e77 100644 --- a/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) +// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) // Package mocks is a generated GoMock package. package mocks diff --git a/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go similarity index 97% rename from iface/mocks/iface/mocks/filter.go rename to client/iface/mocks/iface/mocks/filter.go index 059a2b9a0..17e123abb 100644 --- a/iface/mocks/iface/mocks/filter.go +++ b/client/iface/mocks/iface/mocks/filter.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) +// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) // Package mocks is a generated GoMock package. package mocks diff --git a/iface/mocks/tun.go b/client/iface/mocks/tun.go similarity index 100% rename from iface/mocks/tun.go rename to client/iface/mocks/tun.go diff --git a/iface/netstack/dialer.go b/client/iface/netstack/dialer.go similarity index 100% rename from iface/netstack/dialer.go rename to client/iface/netstack/dialer.go diff --git a/iface/netstack/env.go b/client/iface/netstack/env.go similarity index 100% rename from iface/netstack/env.go rename to client/iface/netstack/env.go diff --git a/iface/netstack/proxy.go b/client/iface/netstack/proxy.go similarity index 100% rename from iface/netstack/proxy.go rename to client/iface/netstack/proxy.go diff --git a/iface/netstack/tun.go b/client/iface/netstack/tun.go similarity index 100% rename from iface/netstack/tun.go rename to client/iface/netstack/tun.go diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go new file mode 100644 index 000000000..8ce73655d --- /dev/null +++ b/client/internal/acl/id/id.go @@ -0,0 +1,64 @@ +package id + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/netip" + "strconv" + + "github.com/netbirdio/netbird/client/firewall/manager" +) + +type RuleID string + +func (r RuleID) GetRuleID() string { + return string(r) +} + +func GenerateRouteRuleKey( + sources []netip.Prefix, + destination netip.Prefix, + proto manager.Protocol, + sPort *manager.Port, + dPort *manager.Port, + action manager.Action, +) RuleID { + manager.SortPrefixes(sources) + + h := sha256.New() + + // Write all fields to the hasher, with delimiters + h.Write([]byte("sources:")) + for _, src := range sources { + h.Write([]byte(src.String())) + h.Write([]byte(",")) + } + + h.Write([]byte("destination:")) + h.Write([]byte(destination.String())) + + h.Write([]byte("proto:")) + h.Write([]byte(proto)) + + h.Write([]byte("sPort:")) + if sPort != nil { + h.Write([]byte(sPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("dPort:")) + if dPort != nil { + h.Write([]byte(dPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("action:")) + h.Write([]byte(strconv.Itoa(int(action)))) + hash := hex.EncodeToString(h.Sum(nil)) + + // prepend destination prefix to be able to identify the rule + return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16])) +} diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index fd2c2c875..ce2a12af1 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "net" + "net/netip" "strconv" "sync" "time" @@ -12,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -23,16 +25,18 @@ type Manager interface { // DefaultManager uses firewall manager to handle type DefaultManager struct { - firewall firewall.Manager - ipsetCounter int - rulesPairs map[string][]firewall.Rule - mutex sync.Mutex + firewall firewall.Manager + ipsetCounter int + peerRulesPairs map[id.RuleID][]firewall.Rule + routeRules map[id.RuleID]struct{} + mutex sync.Mutex } func NewDefaultManager(fm firewall.Manager) *DefaultManager { return &DefaultManager{ - firewall: fm, - rulesPairs: make(map[string][]firewall.Rule), + firewall: fm, + peerRulesPairs: make(map[id.RuleID][]firewall.Rule), + routeRules: make(map[id.RuleID]struct{}), } } @@ -46,7 +50,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { start := time.Now() defer func() { total := 0 - for _, pairs := range d.rulesPairs { + for _, pairs := range d.peerRulesPairs { total += len(pairs) } log.Infof( @@ -59,21 +63,34 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { return } - defer func() { - if err := d.firewall.Flush(); err != nil { - log.Error("failed to flush firewall rules: ", err) - } - }() + d.applyPeerACLs(networkMap) + // If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag, + // then the mgmt server is older than the client, and we need to allow all traffic for routes + isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty + if err := d.firewall.SetLegacyManagement(isLegacy); err != nil { + log.Errorf("failed to set legacy management flag: %v", err) + } + + if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil { + log.Errorf("Failed to apply route ACLs: %v", err) + } + + if err := d.firewall.Flush(); err != nil { + log.Error("failed to flush firewall rules: ", err) + } +} + +func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { rules, squashedProtocols := d.squashAcceptRules(networkMap) - enableSSH := (networkMap.PeerConfig != nil && + enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && - networkMap.PeerConfig.SshConfig.SshEnabled) - if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok { + networkMap.PeerConfig.SshConfig.SshEnabled + if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { enableSSH = enableSSH && !ok } - if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok { + if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { enableSSH = enableSSH && !ok } @@ -83,9 +100,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, Port: strconv.Itoa(ssh.DefaultSSHPort), }) } @@ -97,20 +114,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, ) } - newRulePairs := make(map[string][]firewall.Rule) + newRulePairs := make(map[id.RuleID][]firewall.Rule) ipsetByRuleSelectors := make(map[string]string) for _, r := range rules { @@ -130,29 +147,97 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { break } if len(rules) > 0 { - d.rulesPairs[pairID] = rulePair + d.peerRulesPairs[pairID] = rulePair newRulePairs[pairID] = rulePair } } - for pairID, rules := range d.rulesPairs { + for pairID, rules := range d.peerRulesPairs { if _, ok := newRulePairs[pairID]; !ok { for _, rule := range rules { - if err := d.firewall.DeleteRule(rule); err != nil { - log.Errorf("failed to delete firewall rule: %v", err) + if err := d.firewall.DeletePeerRule(rule); err != nil { + log.Errorf("failed to delete peer firewall rule: %v", err) continue } } - delete(d.rulesPairs, pairID) + delete(d.peerRulesPairs, pairID) } } - d.rulesPairs = newRulePairs + d.peerRulesPairs = newRulePairs +} + +func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { + var newRouteRules = make(map[id.RuleID]struct{}) + for _, rule := range rules { + id, err := d.applyRouteACL(rule) + if err != nil { + return fmt.Errorf("apply route ACL: %w", err) + } + newRouteRules[id] = struct{}{} + } + + for id := range d.routeRules { + if _, ok := newRouteRules[id]; !ok { + if err := d.firewall.DeleteRouteRule(id); err != nil { + log.Errorf("failed to delete route firewall rule: %v", err) + continue + } + delete(d.routeRules, id) + } + } + d.routeRules = newRouteRules + return nil +} + +func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { + if len(rule.SourceRanges) == 0 { + return "", fmt.Errorf("source ranges is empty") + } + + var sources []netip.Prefix + for _, sourceRange := range rule.SourceRanges { + source, err := netip.ParsePrefix(sourceRange) + if err != nil { + return "", fmt.Errorf("parse source range: %w", err) + } + sources = append(sources, source) + } + + var destination netip.Prefix + if rule.IsDynamic { + destination = getDefault(sources[0]) + } else { + var err error + destination, err = netip.ParsePrefix(rule.Destination) + if err != nil { + return "", fmt.Errorf("parse destination: %w", err) + } + } + + protocol, err := convertToFirewallProtocol(rule.Protocol) + if err != nil { + return "", fmt.Errorf("invalid protocol: %w", err) + } + + action, err := convertFirewallAction(rule.Action) + if err != nil { + return "", fmt.Errorf("invalid action: %w", err) + } + + dPorts := convertPortInfo(rule.PortInfo) + + addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action) + if err != nil { + return "", fmt.Errorf("add route rule: %w", err) + } + + return id.RuleID(addedRule.GetRuleID()), nil } func (d *DefaultManager) protoRuleToFirewallRule( r *mgmProto.FirewallRule, ipsetName string, -) (string, []firewall.Rule, error) { +) (id.RuleID, []firewall.Rule, error) { ip := net.ParseIP(r.PeerIP) if ip == nil { return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") @@ -179,16 +264,16 @@ func (d *DefaultManager) protoRuleToFirewallRule( } } - ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "") - if rulesPair, ok := d.rulesPairs[ruleID]; ok { + ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "") + if rulesPair, ok := d.peerRulesPairs[ruleID]; ok { return ruleID, rulesPair, nil } var rules []firewall.Rule switch r.Direction { - case mgmProto.FirewallRule_IN: + case mgmProto.RuleDirection_IN: rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") - case mgmProto.FirewallRule_OUT: + case mgmProto.RuleDirection_OUT: rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") default: return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") @@ -210,7 +295,7 @@ func (d *DefaultManager) addInRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.firewall.AddFiltering( + rule, err := d.firewall.AddPeerFiltering( ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -221,7 +306,7 @@ func (d *DefaultManager) addInRules( return rules, nil } - rule, err = d.firewall.AddFiltering( + rule, err = d.firewall.AddPeerFiltering( ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -239,7 +324,7 @@ func (d *DefaultManager) addOutRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.firewall.AddFiltering( + rule, err := d.firewall.AddPeerFiltering( ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -250,7 +335,7 @@ func (d *DefaultManager) addOutRules( return rules, nil } - rule, err = d.firewall.AddFiltering( + rule, err = d.firewall.AddPeerFiltering( ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -259,21 +344,21 @@ func (d *DefaultManager) addOutRules( return append(rules, rule...), nil } -// getRuleID() returns unique ID for the rule based on its parameters. -func (d *DefaultManager) getRuleID( +// getPeerRuleID() returns unique ID for the rule based on its parameters. +func (d *DefaultManager) getPeerRuleID( ip net.IP, proto firewall.Protocol, direction int, port *firewall.Port, action firewall.Action, comment string, -) string { +) id.RuleID { idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment if port != nil { idStr += port.String() } - return hex.EncodeToString(md5.New().Sum([]byte(idStr))) + return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } // squashAcceptRules does complex logic to convert many rules which allows connection by traffic type @@ -283,7 +368,7 @@ func (d *DefaultManager) getRuleID( // but other has port definitions or has drop policy. func (d *DefaultManager) squashAcceptRules( networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) { +) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { totalIPs := 0 for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { for range p.AllowedIps { @@ -291,14 +376,14 @@ func (d *DefaultManager) squashAcceptRules( } } - type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int + type protoMatch map[mgmProto.RuleProtocol]map[string]int in := protoMatch{} out := protoMatch{} // trace which type of protocols was squashed squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{} + squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} // this function we use to do calculation, can we squash the rules by protocol or not. // We summ amount of Peers IP for given protocol we found in original rules list. @@ -308,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules( // // We zeroed this to notify squash function that this protocol can't be squashed. addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) { - drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != "" + drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" if drop { protocols[r.Protocol] = map[string]int{} return @@ -336,7 +421,7 @@ func (d *DefaultManager) squashAcceptRules( for i, r := range networkMap.FirewallRules { // calculate squash for different directions - if r.Direction == mgmProto.FirewallRule_IN { + if r.Direction == mgmProto.RuleDirection_IN { addRuleToCalculationMap(i, r, in) } else { addRuleToCalculationMap(i, r, out) @@ -345,14 +430,14 @@ func (d *DefaultManager) squashAcceptRules( // order of squashing by protocol is important // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.FirewallRuleProtocol{ - mgmProto.FirewallRule_ALL, - mgmProto.FirewallRule_ICMP, - mgmProto.FirewallRule_TCP, - mgmProto.FirewallRule_UDP, + protocolOrders := []mgmProto.RuleProtocol{ + mgmProto.RuleProtocol_ALL, + mgmProto.RuleProtocol_ICMP, + mgmProto.RuleProtocol_TCP, + mgmProto.RuleProtocol_UDP, } - squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) { + squash := func(matches protoMatch, direction mgmProto.RuleDirection) { for _, protocol := range protocolOrders { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { // don't squash if : @@ -365,12 +450,12 @@ func (d *DefaultManager) squashAcceptRules( squashedRules = append(squashedRules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", Direction: direction, - Action: mgmProto.FirewallRule_ACCEPT, + Action: mgmProto.RuleAction_ACCEPT, Protocol: protocol, }) squashedProtocols[protocol] = struct{}{} - if protocol == mgmProto.FirewallRule_ALL { + if protocol == mgmProto.RuleProtocol_ALL { // if we have ALL traffic type squashed rule // it allows all other type of traffic, so we can stop processing break @@ -378,11 +463,11 @@ func (d *DefaultManager) squashAcceptRules( } } - squash(in, mgmProto.FirewallRule_IN) - squash(out, mgmProto.FirewallRule_OUT) + squash(in, mgmProto.RuleDirection_IN) + squash(out, mgmProto.RuleDirection_OUT) // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok { + if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { return squashedRules, squashedProtocols } @@ -412,26 +497,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) } -func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) { +func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { log.Debugf("rollback ACL to previous state") for _, rules := range newRulePairs { for _, rule := range rules { - if err := d.firewall.DeleteRule(rule); err != nil { + if err := d.firewall.DeletePeerRule(rule); err != nil { log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) } } } } -func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) { +func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { switch protocol { - case mgmProto.FirewallRule_TCP: + case mgmProto.RuleProtocol_TCP: return firewall.ProtocolTCP, nil - case mgmProto.FirewallRule_UDP: + case mgmProto.RuleProtocol_UDP: return firewall.ProtocolUDP, nil - case mgmProto.FirewallRule_ICMP: + case mgmProto.RuleProtocol_ICMP: return firewall.ProtocolICMP, nil - case mgmProto.FirewallRule_ALL: + case mgmProto.RuleProtocol_ALL: return firewall.ProtocolALL, nil default: return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) @@ -442,13 +527,41 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil } -func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) { +func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) { switch action { - case mgmProto.FirewallRule_ACCEPT: + case mgmProto.RuleAction_ACCEPT: return firewall.ActionAccept, nil - case mgmProto.FirewallRule_DROP: + case mgmProto.RuleAction_DROP: return firewall.ActionDrop, nil default: return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action) } } + +func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port { + if portInfo == nil { + return nil + } + + if portInfo.GetPort() != 0 { + return &firewall.Port{ + Values: []int{int(portInfo.GetPort())}, + } + } + + if portInfo.GetRange() != nil { + return &firewall.Port{ + IsRange: true, + Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)}, + } + } + + return nil +} + +func getDefault(prefix netip.Prefix) netip.Prefix { + if prefix.Addr().Is6() { + return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + } + return netip.PrefixFrom(netip.IPv4Unspecified(), 0) +} diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 494d54bf2..7d999669a 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,8 +9,8 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/acl/mocks" - "github.com/netbirdio/netbird/iface" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, Port: "80", }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_DROP, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_UDP, Port: "53", }, }, @@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) { acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 2 { - t.Errorf("firewall rules not applied: %v", acl.rulesPairs) + if len(acl.peerRulesPairs) != 2 { + t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) return } }) t.Run("add extra rules", func(t *testing.T) { existedPairs := map[string]struct{}{} - for id := range acl.rulesPairs { - existedPairs[id] = struct{}{} + for id := range acl.peerRulesPairs { + existedPairs[id.GetRuleID()] = struct{}{} } // remove first rule @@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules, &mgmProto.FirewallRule{ PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_DROP, - Protocol: mgmProto.FirewallRule_ICMP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_ICMP, }, ) acl.ApplyFiltering(networkMap) // we should have one old and one new rule in the existed rules - if len(acl.rulesPairs) != 2 { + if len(acl.peerRulesPairs) != 2 { t.Errorf("firewall rules not applied") return } // check that old rule was removed previousCount := 0 - for id := range acl.rulesPairs { - if _, ok := existedPairs[id]; ok { + for id := range acl.peerRulesPairs { + if _, ok := existedPairs[id.GetRuleID()]; ok { previousCount++ } } @@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRulesIsEmpty = true - if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 { - t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs)) + if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { + t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) return } networkMap.FirewallRulesIsEmpty = false acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 2 { - t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs)) + if len(acl.peerRulesPairs) != 2 { + t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) return } }) @@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, }, } @@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) { case r.PeerIP != "0.0.0.0": t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) return - case r.Direction != mgmProto.FirewallRule_IN: + case r.Direction != mgmProto.RuleDirection_IN: t.Errorf("direction should be IN, got: %v", r.Direction) return - case r.Protocol != mgmProto.FirewallRule_ALL: + case r.Protocol != mgmProto.RuleProtocol_ALL: t.Errorf("protocol should be ALL, got: %v", r.Protocol) return - case r.Action != mgmProto.FirewallRule_ACCEPT: + case r.Action != mgmProto.RuleAction_ACCEPT: t.Errorf("action should be ACCEPT, got: %v", r.Action) return } @@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) { case r.PeerIP != "0.0.0.0": t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) return - case r.Direction != mgmProto.FirewallRule_OUT: + case r.Direction != mgmProto.RuleDirection_OUT: t.Errorf("direction should be OUT, got: %v", r.Direction) return - case r.Protocol != mgmProto.FirewallRule_ALL: + case r.Protocol != mgmProto.RuleProtocol_ALL: t.Errorf("protocol should be ALL, got: %v", r.Protocol) return - case r.Action != mgmProto.FirewallRule_ACCEPT: + case r.Action != mgmProto.RuleAction_ACCEPT: t.Errorf("action should be ACCEPT, got: %v", r.Action) return } @@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, }, }, } @@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, }, }, } @@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 4 { - t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs)) + if len(acl.peerRulesPairs) != 4 { + t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) return } } diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go index 621b29513..3ed12b6dd 100644 --- a/client/internal/acl/mocks/iface_mapper.go +++ b/client/internal/acl/mocks/iface_mapper.go @@ -8,7 +8,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - iface "github.com/netbirdio/netbird/iface" + iface "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) // MockIFaceMapper is a mock of IFaceMapper interface. @@ -77,7 +78,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call { } // SetFilter mocks base method. -func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error { +func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetFilter", arg0) ret0, _ := ret[0].(error) diff --git a/client/internal/config.go b/client/internal/config.go index 1df1e0547..ee54c6380 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -16,9 +16,9 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/ssh" - "github.com/netbirdio/netbird/iface" mgm "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/util" ) diff --git a/client/internal/connect.go b/client/internal/connect.go index 36b340cfb..74dc1f1b5 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -17,13 +17,14 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/iface" mgm "github.com/netbirdio/netbird/management/client" mgmProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/relay/auth/hmac" @@ -70,7 +71,7 @@ func (c *ConnectClient) RunWithProbes( // RunOnAndroid with main logic on mobile system func (c *ConnectClient) RunOnAndroid( - tunAdapter iface.TunAdapter, + tunAdapter device.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, @@ -205,7 +206,7 @@ func (c *ConnectClient) run( localPeerState := peer.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), PubKey: myPrivateKey.PublicKey().String(), - KernelInterface: iface.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded(), FQDN: loginResp.GetPeerConfig().GetFqdn(), } c.statusRecorder.UpdateLocalPeerState(localPeerState) @@ -268,12 +269,6 @@ func (c *ConnectClient) run( checks := loginResp.GetChecks() c.engineMutex.Lock() - if c.engine != nil && c.engine.ctx.Err() != nil { - log.Info("Stopping Netbird Engine") - if err := c.engine.Stop(); err != nil { - log.Errorf("Failed to stop engine: %v", err) - } - } c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engineMutex.Unlock() @@ -293,6 +288,15 @@ func (c *ConnectClient) run( } <-engineCtx.Done() + c.engineMutex.Lock() + if c.engine != nil && c.engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) + if err := c.engine.Stop(); err != nil { + log.Errorf("Failed to stop engine: %v", err) + } + c.engine = nil + } + c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() diff --git a/client/internal/dns/response_writer_test.go b/client/internal/dns/response_writer_test.go index 5a0047700..857964406 100644 --- a/client/internal/dns/response_writer_test.go +++ b/client/internal/dns/response_writer_test.go @@ -9,7 +9,7 @@ import ( "github.com/google/gopacket/layers" "github.com/miekg/dns" - "github.com/netbirdio/netbird/iface/mocks" + "github.com/netbirdio/netbird/client/iface/mocks" ) func TestResponseWriterLocalAddr(t *testing.T) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index b9552bc17..53d18a678 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -15,16 +15,18 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" + pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/iface" - pfmock "github.com/netbirdio/netbird/iface/mocks" ) type mocWGIface struct { - filter iface.PacketFilter + filter device.PacketFilter } func (w *mocWGIface) Name() string { @@ -43,11 +45,11 @@ func (w *mocWGIface) ToInterface() *net.Interface { panic("implement me") } -func (w *mocWGIface) GetFilter() iface.PacketFilter { +func (w *mocWGIface) GetFilter() device.PacketFilter { return w.filter } -func (w *mocWGIface) GetDevice() *iface.DeviceWrapper { +func (w *mocWGIface) GetDevice() *device.FilteredDevice { panic("implement me") } @@ -59,13 +61,13 @@ func (w *mocWGIface) IsUserspaceBind() bool { return false } -func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error { +func (w *mocWGIface) SetFilter(filter device.PacketFilter) error { w.filter = filter return nil } -func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) { - return iface.WGStats{}, nil +func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) { + return configurer.WGStats{}, nil } var zoneRecords = []nbdns.SimpleRecord{ diff --git a/client/internal/dns/wgiface.go b/client/internal/dns/wgiface.go index 2f08e8d52..69bc83659 100644 --- a/client/internal/dns/wgiface.go +++ b/client/internal/dns/wgiface.go @@ -5,7 +5,9 @@ package dns import ( "net" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) // WGIface defines subset methods of interface required for manager @@ -14,7 +16,7 @@ type WGIface interface { Address() iface.WGAddress ToInterface() *net.Interface IsUserspaceBind() bool - GetFilter() iface.PacketFilter - GetDevice() *iface.DeviceWrapper - GetStats(peerKey string) (iface.WGStats, error) + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/client/internal/dns/wgiface_windows.go b/client/internal/dns/wgiface_windows.go index f8bb80fb9..765132fdb 100644 --- a/client/internal/dns/wgiface_windows.go +++ b/client/internal/dns/wgiface_windows.go @@ -1,14 +1,18 @@ package dns -import "github.com/netbirdio/netbird/iface" +import ( + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" +) // WGIface defines subset methods of interface required for manager type WGIface interface { Name() string Address() iface.WGAddress IsUserspaceBind() bool - GetFilter() iface.PacketFilter - GetDevice() *iface.DeviceWrapper - GetStats(peerKey string) (iface.WGStats, error) + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 463507ad8..eac8ec098 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -23,9 +23,12 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/relay" @@ -36,8 +39,6 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" mgm "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" @@ -250,6 +251,13 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") + // stop/restore DNS first so dbus and friends don't complain because of a missing interface + e.stopDNSServer() + + if e.routeManager != nil { + e.routeManager.Stop() + } + err := e.removeAllPeers() if err != nil { return fmt.Errorf("failed to remove all peers: %s", err) @@ -619,7 +627,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{ IP: e.config.WgAddr, PubKey: e.config.WgPrivateKey.PublicKey().String(), - KernelInterface: iface.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded(), FQDN: conf.GetFqdn(), }) @@ -704,6 +712,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } + // Apply ACLs in the beginning to avoid security leaks + if e.acl != nil { + e.acl.ApplyFiltering(networkMap) + } + protoRoutes := networkMap.GetRoutes() if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} @@ -770,10 +783,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { log.Errorf("failed to update dns server, err: %v", err) } - if e.acl != nil { - e.acl.ApplyFiltering(networkMap) - } - e.networkSerial = serial // Test received (upstream) servers for availability right away instead of upon usage. @@ -1114,18 +1123,12 @@ func (e *Engine) close() { } } - // stop/restore DNS first so dbus and friends don't complain because of a missing interface - e.stopDNSServer() - - if e.routeManager != nil { - e.routeManager.Stop() - } - log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) } + e.wgInterface = nil } if !isNil(e.sshServer) { @@ -1164,15 +1167,15 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { log.Errorf("failed to create pion's stdnet: %s", err) } - var mArgs *iface.MobileIFaceArguments + var mArgs *device.MobileIFaceArguments switch runtime.GOOS { case "android": - mArgs = &iface.MobileIFaceArguments{ + mArgs = &device.MobileIFaceArguments{ TunAdapter: e.mobileDep.TunAdapter, TunFd: int(e.mobileDep.FileDescriptor), } case "ios": - mArgs = &iface.MobileIFaceArguments{ + mArgs = &device.MobileIFaceArguments{ TunFd: int(e.mobileDep.FileDescriptor), } default: @@ -1393,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() { } // Set a new timer to debounce rapid network changes - debounceTimer = time.AfterFunc(1*time.Second, func() { + debounceTimer = time.AfterFunc(2*time.Second, func() { // This function is called after the debounce period mu.Lock() defer mu.Unlock() @@ -1424,6 +1427,11 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { } func (e *Engine) stopDNSServer() { + if e.dnsServer == nil { + return + } + e.dnsServer.Stop() + e.dnsServer = nil err := fmt.Errorf("DNS server stopped") nsGroupStates := e.statusRecorder.GetDNSStates() for i := range nsGroupStates { @@ -1431,10 +1439,6 @@ func (e *Engine) stopDNSServer() { nsGroupStates[i].Error = err } e.statusRecorder.UpdateDNSStates(nsGroupStates) - if e.dnsServer != nil { - e.dnsServer.Stop() - e.dnsServer = nil - } } // isChecksEqual checks if two slices of checks are equal. diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f30566380..74b10ee44 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -6,7 +6,6 @@ import ( "net" "net/netip" "os" - "path/filepath" "runtime" "strings" "sync" @@ -25,14 +24,15 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" mgmt "github.com/netbirdio/netbird/management/client" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" @@ -823,20 +823,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { func TestEngine_MultiplePeers(t *testing.T) { // log.SetLevel(log.DebugLevel) - dir := t.TempDir() - - err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - err = os.Remove(filepath.Join(dir, "store.json")) //nolint - if err != nil { - t.Fatal(err) - return - } - }() - ctx, cancel := context.WithCancel(CtxInitState(context.Background())) defer cancel() @@ -846,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) { return } defer sigServer.Stop() - mgmtServer, mgmtAddr, err := startManagement(t, dir) + mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql") if err != nil { t.Fatal(err) return @@ -874,7 +860,7 @@ func TestEngine_MultiplePeers(t *testing.T) { mu.Lock() defer mu.Unlock() guid := fmt.Sprintf("{%s}", uuid.New().String()) - iface.CustomWindowsGUIDString = strings.ToLower(guid) + device.CustomWindowsGUIDString = strings.ToLower(guid) err = engine.Start() if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) @@ -1056,7 +1042,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - srv, err := signalServer.NewServer(otel.Meter("")) + srv, err := signalServer.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) proto.RegisterSignalExchangeServer(s, srv) @@ -1069,7 +1055,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { return s, lis.Addr().String(), nil } -func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) { +func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) { t.Helper() config := &server.Config{ @@ -1094,7 +1080,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 2355c67c3..2b0c92cc6 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -1,16 +1,16 @@ package internal import ( + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" ) // MobileDependency collect all dependencies for mobile platform type MobileDependency struct { // Android only - TunAdapter iface.TunAdapter + TunAdapter device.TunAdapter IFaceDiscover stdnet.ExternalIFaceDiscover NetworkChangeListener listener.NetworkChangeListener HostDNSAddresses []string diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ea6d892b9..1b740388d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -15,9 +15,10 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" - "github.com/netbirdio/netbird/iface" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -31,6 +32,8 @@ const ( connPriorityRelay ConnPriority = 1 connPriorityICETurn ConnPriority = 1 connPriorityICEP2P ConnPriority = 2 + + reconnectMaxElapsedTime = 30 * time.Minute ) type WgConfig struct { @@ -79,9 +82,8 @@ type Conn struct { config ConnConfig statusRecorder *Status wgProxyFactory *wgproxy.Factory - wgProxyICE wgproxy.Proxy - wgProxyRelay wgproxy.Proxy signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager allowedIPsIP string handshaker *Handshaker @@ -102,11 +104,14 @@ type Conn struct { beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc - endpointRelay *net.UDPAddr + wgProxyICE wgproxy.Proxy + wgProxyRelay wgproxy.Proxy // for reconnection operations iCEDisconnected chan bool relayDisconnected chan bool + connMonitor *ConnMonitor + reconnectCh <-chan struct{} } // NewConn creates a new not opened Conn to the remote peer. @@ -122,21 +127,31 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu connLog := log.WithField("peer", config.Key) var conn = &Conn{ - log: connLog, - ctx: ctx, - ctxCancel: ctxCancel, - config: config, - statusRecorder: statusRecorder, - wgProxyFactory: wgProxyFactory, - signaler: signaler, - relayManager: relayManager, - allowedIPsIP: allowedIPsIP.String(), - statusRelay: NewAtomicConnStatus(), - statusICE: NewAtomicConnStatus(), + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + wgProxyFactory: wgProxyFactory, + signaler: signaler, + iFaceDiscover: iFaceDiscover, + relayManager: relayManager, + allowedIPsIP: allowedIPsIP.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), + iCEDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1), } + conn.connMonitor, conn.reconnectCh = NewConnMonitor( + signaler, + iFaceDiscover, + config, + conn.relayDisconnected, + conn.iCEDisconnected, + ) + rFns := WorkerRelayCallbacks{ OnConnReady: conn.relayConnectionIsReady, OnDisconnected: conn.onWorkerRelayStateDisconnected, @@ -199,6 +214,8 @@ func (conn *Conn) startHandshakeAndReconnect() { conn.log.Errorf("failed to send initial offer: %v", err) } + go conn.connMonitor.Start(conn.ctx) + if conn.workerRelay.IsController() { conn.reconnectLoopWithRetry() } else { @@ -239,8 +256,7 @@ func (conn *Conn) Close() { conn.wgProxyICE = nil } - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } @@ -308,12 +324,14 @@ func (conn *Conn) reconnectLoopWithRetry() { // With it, we can decrease to send necessary offer select { case <-conn.ctx.Done(): + return case <-time.After(3 * time.Second): } ticker := conn.prepareExponentTicker() defer ticker.Stop() time.Sleep(1 * time.Second) + for { select { case t := <-ticker.C: @@ -341,20 +359,11 @@ func (conn *Conn) reconnectLoopWithRetry() { if err != nil { conn.log.Errorf("failed to do handshake: %v", err) } - case changed := <-conn.relayDisconnected: - if !changed { - continue - } - conn.log.Debugf("Relay state changed, reset reconnect timer") - ticker.Stop() - ticker = conn.prepareExponentTicker() - case changed := <-conn.iCEDisconnected: - if !changed { - continue - } - conn.log.Debugf("ICE state changed, reset reconnect timer") + + case <-conn.reconnectCh: ticker.Stop() ticker = conn.prepareExponentTicker() + case <-conn.ctx.Done(): conn.log.Debugf("context is done, stop reconnect loop") return @@ -365,10 +374,10 @@ func (conn *Conn) reconnectLoopWithRetry() { func (conn *Conn) prepareExponentTicker() *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 0.01, + RandomizationFactor: 0.1, Multiplier: 2, MaxInterval: conn.config.Timeout, - MaxElapsedTime: 0, + MaxElapsedTime: reconnectMaxElapsedTime, Stop: backoff.Stop, Clock: backoff.SystemClock, }, conn.ctx) @@ -419,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon conn.log.Debugf("ICE connection is ready") - conn.statusICE.Set(StatusConnected) - - defer conn.updateIceState(iceConnInfo) - if conn.currentConnPriority > priority { + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) return } conn.log.Infof("set ICE to active connection") - endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) - if err != nil { - return + var ( + ep *net.UDPAddr + wgProxy wgproxy.Proxy + err error + ) + if iceConnInfo.RelayedOnLocal { + wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) + if err != nil { + conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) + return + } + ep = wgProxy.EndpointAddr() + conn.wgProxyICE = wgProxy + } else { + directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) + if err != nil { + log.Errorf("failed to resolveUDPaddr") + conn.handleConfigurationFailure(err, nil) + return + } + ep = directEp } - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) - - conn.connIDICE = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } conn.workerRelay.DisableWgWatcher() - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { - if wgProxy != nil { - if err := wgProxy.CloseConn(); err != nil { - conn.log.Warnf("Failed to close turn connection: %v", err) - } - } - conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Pause() + } + + if wgProxy != nil { + wgProxy.Work() + } + + if err = conn.configureWGEndpoint(ep); err != nil { + conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() - - if conn.wgProxyICE != nil { - if err := conn.wgProxyICE.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyICE = wgProxy - conn.currentConnPriority = priority - + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) } @@ -481,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.log.Tracef("ICE connection state changed to %s", newState) + if conn.wgProxyICE != nil { + if err := conn.wgProxyICE.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + // switch back to relay connection - if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { + if conn.isReadyToUpgrade() { conn.log.Debugf("ICE disconnected, set Relay to active connection") - err := conn.configureWGEndpoint(conn.endpointRelay) - if err != nil { + conn.wgProxyRelay.Work() + + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } conn.workerRelay.EnableWgWatcher(conn.ctx) @@ -495,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - select { - case conn.iCEDisconnected <- changed: - default: - } + conn.notifyReconnectLoopICEDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -519,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.ctx.Err() != nil { if err := rci.relayedConn.Close(); err != nil { - log.Warnf("failed to close unnecessary relayed connection: %v", err) + conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) } return } - conn.log.Debugf("Relay connection is ready to use") - conn.statusRelay.Set(StatusConnected) + conn.log.Debugf("Relay connection has been established, setup the WireGuard") - wgProxy := conn.wgProxyFactory.GetProxy() - endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn) + wgProxy, err := conn.newProxy(rci.relayedConn) if err != nil { conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return } - conn.log.Infof("created new wgProxy for relay connection: %s", endpoint) - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.endpointRelay = endpointUdpAddr - conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) - defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) - - if conn.currentConnPriority > connPriorityRelay { - if conn.statusICE.Get() == StatusConnected { - log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - return - } + if conn.iceP2PIsActive() { + conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) + conn.wgProxyRelay = wgProxy + conn.statusRelay.Set(StatusConnected) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + return } - conn.connIDRelay = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { + wgProxy.Work() + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } - conn.log.Errorf("Failed to update wg peer configuration: %v", err) + conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) return } conn.workerRelay.EnableWgWatcher(conn.ctx) + wgConfigWorkaround() - - if conn.wgProxyRelay != nil { - if err := conn.wgProxyRelay.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyRelay = wgProxy conn.currentConnPriority = connPriorityRelay - + conn.statusRelay.Set(StatusConnected) + conn.wgProxyRelay = wgProxy + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) } @@ -586,25 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { return } - if conn.wgProxyRelay != nil { - log.Debugf("relayed connection is closed, clean up WireGuard config") - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + conn.log.Debugf("relay connection is disconnected") + + if conn.currentConnPriority == connPriorityRelay { + conn.log.Debugf("clean up WireGuard config") + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } + } - conn.endpointRelay = nil + if conn.wgProxyRelay != nil { _ = conn.wgProxyRelay.CloseConn() conn.wgProxyRelay = nil } changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - - select { - case conn.relayDisconnected <- changed: - default: - } + conn.notifyReconnectLoopRelayDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -612,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { Relayed: conn.isRelayed(), ConnStatusUpdate: time.Now(), } - - err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState) - if err != nil { + if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) } } @@ -680,7 +681,7 @@ func (conn *Conn) setStatusToDisconnected() { // todo rethink status updates conn.log.Debugf("error while updating peer's state, err: %v", err) } - if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil { + if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil { conn.log.Debugf("failed to reset wireguard stats for peer: %s", err) } } @@ -750,6 +751,16 @@ func (conn *Conn) isConnected() bool { return true } +func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error { + conn.connIDICE = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connIDICE, ip); err != nil { + return err + } + } + return nil +} + func (conn *Conn) freeUpConnID() { if conn.connIDRelay != "" { for _, hook := range conn.afterRemovePeerHooks { @@ -770,21 +781,52 @@ func (conn *Conn) freeUpConnID() { } } -func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { - if !iceConnInfo.RelayedOnLocal { - return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil - } - conn.log.Debugf("setup ice turn connection") +func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { + conn.log.Debugf("setup proxied WireGuard connection") wgProxy := conn.wgProxyFactory.GetProxy() - ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) - if err != nil { + if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) - if errClose := wgProxy.CloseConn(); errClose != nil { - conn.log.Warnf("failed to close turn proxy connection: %v", errClose) - } - return nil, nil, err + return nil, err + } + return wgProxy, nil +} + +func (conn *Conn) isReadyToUpgrade() bool { + return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay +} + +func (conn *Conn) iceP2PIsActive() bool { + return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected +} + +func (conn *Conn) removeWgPeer() error { + return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) +} + +func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { + select { + case conn.relayDisconnected <- changed: + default: + } +} + +func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { + select { + case conn.iCEDisconnected <- changed: + default: + } +} + +func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { + conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if wgProxy != nil { + if ierr := wgProxy.CloseConn(); ierr != nil { + conn.log.Warnf("Failed to close wg proxy: %v", ierr) + } + } + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Work() } - return ep, wgProxy, nil } func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go new file mode 100644 index 000000000..75722c990 --- /dev/null +++ b/client/internal/peer/conn_monitor.go @@ -0,0 +1,212 @@ +package peer + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/pion/ice/v3" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +const ( + signalerMonitorPeriod = 5 * time.Second + candidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second +) + +type ConnMonitor struct { + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + config ConnConfig + relayDisconnected chan bool + iCEDisconnected chan bool + reconnectCh chan struct{} + currentCandidates []ice.Candidate + candidatesMu sync.Mutex +} + +func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) { + reconnectCh := make(chan struct{}, 1) + cm := &ConnMonitor{ + signaler: signaler, + iFaceDiscover: iFaceDiscover, + config: config, + relayDisconnected: relayDisconnected, + iCEDisconnected: iCEDisconnected, + reconnectCh: reconnectCh, + } + return cm, reconnectCh +} + +func (cm *ConnMonitor) Start(ctx context.Context) { + signalerReady := make(chan struct{}, 1) + go cm.monitorSignalerReady(ctx, signalerReady) + + localCandidatesChanged := make(chan struct{}, 1) + go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged) + + for { + select { + case changed := <-cm.relayDisconnected: + if !changed { + continue + } + log.Debugf("Relay state changed, triggering reconnect") + cm.triggerReconnect() + + case changed := <-cm.iCEDisconnected: + if !changed { + continue + } + log.Debugf("ICE state changed, triggering reconnect") + cm.triggerReconnect() + + case <-signalerReady: + log.Debugf("Signaler became ready, triggering reconnect") + cm.triggerReconnect() + + case <-localCandidatesChanged: + log.Debugf("Local candidates changed, triggering reconnect") + cm.triggerReconnect() + + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) { + if cm.signaler == nil { + return + } + + ticker := time.NewTicker(signalerMonitorPeriod) + defer ticker.Stop() + + lastReady := true + for { + select { + case <-ticker.C: + currentReady := cm.signaler.Ready() + if !lastReady && currentReady { + select { + case signalerReady <- struct{}{}: + default: + } + } + lastReady = currentReady + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) { + ufrag, pwd, err := generateICECredentials() + if err != nil { + log.Warnf("Failed to generate ICE credentials: %v", err) + return + } + + ticker := time.NewTicker(candidatesMonitorPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil { + log.Warnf("Failed to handle candidate tick: %v", err) + } + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error { + log.Debugf("Gathering ICE candidates") + + transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList) + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd) + if err != nil { + return fmt.Errorf("create ICE agent: %w", err) + } + defer func() { + if err := agent.Close(); err != nil { + log.Warnf("Failed to close ICE agent: %v", err) + } + }() + + gatherDone := make(chan struct{}) + err = agent.OnCandidate(func(c ice.Candidate) { + log.Tracef("Got candidate: %v", c) + if c == nil { + close(gatherDone) + } + }) + if err != nil { + return fmt.Errorf("set ICE candidate handler: %w", err) + } + + if err := agent.GatherCandidates(); err != nil { + return fmt.Errorf("gather ICE candidates: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) + defer cancel() + + select { + case <-ctx.Done(): + return fmt.Errorf("wait for gathering: %w", ctx.Err()) + case <-gatherDone: + } + + candidates, err := agent.GetLocalCandidates() + if err != nil { + return fmt.Errorf("get local candidates: %w", err) + } + log.Tracef("Got candidates: %v", candidates) + + if changed := cm.updateCandidates(candidates); changed { + select { + case localCandidatesChanged <- struct{}{}: + default: + } + } + + return nil +} + +func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool { + cm.candidatesMu.Lock() + defer cm.candidatesMu.Unlock() + + if len(cm.currentCandidates) != len(newCandidates) { + cm.currentCandidates = newCandidates + return true + } + + for i, candidate := range cm.currentCandidates { + if candidate.Address() != newCandidates[i].Address() { + cm.currentCandidates = newCandidates + return true + } + } + + return false +} + +func (cm *ConnMonitor) triggerReconnect() { + select { + case cm.reconnectCh <- struct{}{}: + default: + } +} diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 22e5409f8..b4926a9d2 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -9,9 +9,9 @@ import ( "github.com/magiconair/properties/assert" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 915fa63f0..a28992fac 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -11,8 +11,8 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/relay" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" relayClient "github.com/netbirdio/netbird/relay/client" ) @@ -203,7 +203,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) { state, ok := d.peers[peerPubKey] if !ok { - return State{}, iface.ErrPeerNotFound + return State{}, configurer.ErrPeerNotFound } return state, nil } @@ -412,7 +412,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { } // UpdateWireGuardPeerState updates the WireGuard bits of the peer state -func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error { +func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error { d.mux.Lock() defer d.mux.Unlock() diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go index ae31ebbf0..96d211dbc 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/stdnet.go @@ -6,6 +6,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" ) -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList) +func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ifaceBlacklist) } diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go index b411405bb..a39a03b1c 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/stdnet_android.go @@ -2,6 +2,6 @@ package peer import "github.com/netbirdio/netbird/client/internal/stdnet" -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) +func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 8bf1b7568..c86c1858f 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -15,9 +15,9 @@ import ( "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/route" ) @@ -233,41 +233,16 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { - transportNet, err := w.newStdNet() + transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) if err != nil { w.log.Errorf("failed to create pion's stdnet: %s", err) } - iceKeepAlive := iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI), - CandidateTypes: relaySupport, - InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList), - UDPMux: w.config.ICEConfig.UDPMux, - UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx, - NAT1To1IPs: w.config.ICEConfig.NATExternalIPs, - Net: transportNet, - FailedTimeout: &failedTimeout, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - LocalUfrag: w.localUfrag, - LocalPwd: w.localPwd, - } - - if w.config.ICEConfig.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - w.sentExtraSrflx = false - agent, err := ice.NewAgent(agentConfig) + + agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) if err != nil { - return nil, err + return nil, fmt.Errorf("create agent: %w", err) } err = agent.OnCandidate(w.onICECandidate) @@ -390,6 +365,36 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA } } +func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI), + CandidateTypes: candidateTypes, + InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList), + UDPMux: config.ICEConfig.UDPMux, + UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx, + NAT1To1IPs: config.ICEConfig.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: ufrag, + LocalPwd: pwd, + } + + if config.ICEConfig.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + return ice.NewAgent(agentConfig) +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 6bb385d3e..c02fccebc 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -74,7 +74,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) if err != nil { if errors.Is(err, relayClient.ErrConnAlreadyExists) { - w.log.Infof("do not need to reopen relay connection") + w.log.Debugf("handled offer by reusing existing relay connection") return } w.log.Errorf("failed to open connection via Relay: %s", err) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index db2caea7f..eaa232151 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -10,12 +10,12 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 5897031e7..ac94d4a5c 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -13,10 +13,10 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) @@ -303,7 +303,7 @@ func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]neti var merr *multierror.Error for _, prefix := range prefixes { - if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil { + if _, err := r.routeRefCounter.Increment(prefix, struct{}{}); err != nil { merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err)) continue } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index cdfd322bd..d7ddf7ae8 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -14,6 +14,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" @@ -21,7 +23,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" - "github.com/netbirdio/netbird/iface" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -87,10 +88,10 @@ func NewManager( } dm.routeRefCounter = refcounter.New( - func(prefix netip.Prefix, _ any) (any, error) { - return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) + func(prefix netip.Prefix, _ struct{}) (struct{}, error) { + return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) }, - func(prefix netip.Prefix, _ any) error { + func(prefix netip.Prefix, _ struct{}) error { return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface()) }, ) @@ -102,7 +103,7 @@ func NewManager( }, func(prefix netip.Prefix, peerKey string) error { if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { - if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) { + if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) { return err } log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err) diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2995e2740..2f26f7a5e 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 58a66715c..908279c88 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -5,9 +5,9 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/routeselector" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/util/net" ) diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index f1d696ad9..65ea0f708 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -3,7 +3,8 @@ package refcounter import ( "errors" "fmt" - "net/netip" + "runtime" + "strings" "sync" "github.com/hashicorp/go-multierror" @@ -12,118 +13,153 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" ) -// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix. +const logLevel = log.TraceLevel + +// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key. var ErrIgnore = errors.New("ignore") +// Ref holds the reference count and associated data for a key. type Ref[O any] struct { Count int Out O } -type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error) -type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error +// AddFunc is the function type for adding a new key. +// Key is the type of the key (e.g., netip.Prefix). +type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error) -type Counter[I, O any] struct { - // refCountMap keeps track of the reference Ref for prefixes - refCountMap map[netip.Prefix]Ref[O] +// RemoveFunc is the function type for removing a key. +type RemoveFunc[Key, O any] func(key Key, out O) error + +// Counter is a generic reference counter for managing keys and their associated data. +// Key: The type of the key (e.g., netip.Prefix, string). +// +// I: The input type for the AddFunc. It is the input type for additional data needed +// when adding a key, it is passed as the second argument to AddFunc. +// +// O: The output type for the AddFunc and RemoveFunc. This is the output returned by AddFunc. +// It is stored and passed to RemoveFunc when the reference count reaches 0. +// +// The types can be aliased to a specific type using the following syntax: +// +// type RouteRefCounter = Counter[netip.Prefix, any, any] +type Counter[Key comparable, I, O any] struct { + // refCountMap keeps track of the reference Ref for keys + refCountMap map[Key]Ref[O] refCountMu sync.Mutex - // idMap keeps track of the prefixes associated with an ID for removal - idMap map[string][]netip.Prefix + // idMap keeps track of the keys associated with an ID for removal + idMap map[string][]Key idMu sync.Mutex - add AddFunc[I, O] - remove RemoveFunc[I, O] + add AddFunc[Key, I, O] + remove RemoveFunc[Key, O] } -// New creates a new Counter instance -func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] { - return &Counter[I, O]{ - refCountMap: map[netip.Prefix]Ref[O]{}, - idMap: map[string][]netip.Prefix{}, +// New creates a new Counter instance. +// Usage example: +// +// counter := New[netip.Prefix, string, string]( +// func(key netip.Prefix, in string) (out string, err error) { ... }, +// func(key netip.Prefix, out string) error { ... },` +// ) +func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) *Counter[Key, I, O] { + return &Counter[Key, I, O]{ + refCountMap: map[Key]Ref[O]{}, + idMap: map[string][]Key{}, add: add, remove: remove, } } -// Increment increments the reference count for the given prefix. -// If this is the first reference to the prefix, the AddFunc is called. -func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) { +// Get retrieves the current reference count and associated data for a key. +// If the key doesn't exist, it returns a zero value Ref and false. +func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() - ref := rm.refCountMap[prefix] - log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + ref, ok := rm.refCountMap[key] + return ref, ok +} - // Call AddFunc only if it's a new prefix +// Increment increments the reference count for the given key. +// If this is the first reference to the key, the AddFunc is called. +func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + + ref := rm.refCountMap[key] + logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) + + // Call AddFunc only if it's a new key if ref.Count == 0 { - log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out) - out, err := rm.add(prefix, in) + logCallerF("Calling add for key %v", key) + out, err := rm.add(key, in) if errors.Is(err, ErrIgnore) { return ref, nil } if err != nil { - return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err) + return ref, fmt.Errorf("failed to add for key %v: %w", key, err) } ref.Out = out } ref.Count++ - rm.refCountMap[prefix] = ref + rm.refCountMap[key] = ref return ref, nil } -// IncrementWithID increments the reference count for the given prefix and groups it under the given ID. -// If this is the first reference to the prefix, the AddFunc is called. -func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) { +// IncrementWithID increments the reference count for the given key and groups it under the given ID. +// If this is the first reference to the key, the AddFunc is called. +func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { rm.idMu.Lock() defer rm.idMu.Unlock() - ref, err := rm.Increment(prefix, in) + ref, err := rm.Increment(key, in) if err != nil { return ref, fmt.Errorf("with ID: %w", err) } - rm.idMap[id] = append(rm.idMap[id], prefix) + rm.idMap[id] = append(rm.idMap[id], key) return ref, nil } -// Decrement decrements the reference count for the given prefix. +// Decrement decrements the reference count for the given key. // If the reference count reaches 0, the RemoveFunc is called. -func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) { +func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() - ref, ok := rm.refCountMap[prefix] + ref, ok := rm.refCountMap[key] if !ok { - log.Tracef("No reference found for prefix %s", prefix) + logCallerF("No reference found for key %v", key) return ref, nil } - log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + logCallerF("Decreasing ref count [%d -> %d] for key %v with Out [%v]", ref.Count, ref.Count-1, key, ref.Out) if ref.Count == 1 { - log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out) - if err := rm.remove(prefix, ref.Out); err != nil { - return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err) + logCallerF("Calling remove for key %v", key) + if err := rm.remove(key, ref.Out); err != nil { + return ref, fmt.Errorf("remove for key %v: %w", key, err) } - delete(rm.refCountMap, prefix) + delete(rm.refCountMap, key) } else { ref.Count-- - rm.refCountMap[prefix] = ref + rm.refCountMap[key] = ref } return ref, nil } -// DecrementWithID decrements the reference count for all prefixes associated with the given ID. +// DecrementWithID decrements the reference count for all keys associated with the given ID. // If the reference count reaches 0, the RemoveFunc is called. -func (rm *Counter[I, O]) DecrementWithID(id string) error { +func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { rm.idMu.Lock() defer rm.idMu.Unlock() var merr *multierror.Error - for _, prefix := range rm.idMap[id] { - if _, err := rm.Decrement(prefix); err != nil { + for _, key := range rm.idMap[id] { + if _, err := rm.Decrement(key); err != nil { merr = multierror.Append(merr, err) } } @@ -132,24 +168,77 @@ func (rm *Counter[I, O]) DecrementWithID(id string) error { return nberrors.FormatErrorOrNil(merr) } -// Flush removes all references and calls RemoveFunc for each prefix. -func (rm *Counter[I, O]) Flush() error { +// Flush removes all references and calls RemoveFunc for each key. +func (rm *Counter[Key, I, O]) Flush() error { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() rm.idMu.Lock() defer rm.idMu.Unlock() var merr *multierror.Error - for prefix := range rm.refCountMap { - log.Tracef("Removing for prefix %s", prefix) - ref := rm.refCountMap[prefix] - if err := rm.remove(prefix, ref.Out); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err)) + for key := range rm.refCountMap { + logCallerF("Calling remove for key %v", key) + ref := rm.refCountMap[key] + if err := rm.remove(key, ref.Out); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove for key %v: %w", key, err)) } } - rm.refCountMap = map[netip.Prefix]Ref[O]{} - rm.idMap = map[string][]netip.Prefix{} + clear(rm.refCountMap) + clear(rm.idMap) return nberrors.FormatErrorOrNil(merr) } + +// Clear removes all references without calling RemoveFunc. +func (rm *Counter[Key, I, O]) Clear() { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + + clear(rm.refCountMap) + clear(rm.idMap) +} + +func getCallerInfo(depth int, maxDepth int) (string, bool) { + if depth >= maxDepth { + return "", false + } + + pc, _, _, ok := runtime.Caller(depth) + if !ok { + return "", false + } + + if details := runtime.FuncForPC(pc); details != nil { + name := details.Name() + + lastDotIndex := strings.LastIndex(name, "/") + if lastDotIndex != -1 { + name = name[lastDotIndex+1:] + } + + if strings.HasPrefix(name, "refcounter.") { + // +2 to account for recursion + return getCallerInfo(depth+2, maxDepth) + } + + return name, true + } + + return "", false +} + +// logCaller logs a message with the package name and method of the function that called the current function. +func logCallerF(format string, args ...interface{}) { + if log.GetLevel() < logLevel { + return + } + + if callerName, ok := getCallerInfo(3, 18); ok { + format = fmt.Sprintf("[%s] %s", callerName, format) + } + + log.StandardLogger().Logf(logLevel, format, args...) +} diff --git a/client/internal/routemanager/refcounter/types.go b/client/internal/routemanager/refcounter/types.go index 6753b64ef..aadac3e25 100644 --- a/client/internal/routemanager/refcounter/types.go +++ b/client/internal/routemanager/refcounter/types.go @@ -1,7 +1,9 @@ package refcounter +import "net/netip" + // RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement -type RouteRefCounter = Counter[any, any] +type RouteRefCounter = Counter[netip.Prefix, struct{}, struct{}] // AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement -type AllowedIPsRefCounter = Counter[string, string] +type AllowedIPsRefCounter = Counter[netip.Prefix, string, string] diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index 2057b9cc8..c75a0a7f2 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -7,8 +7,8 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) { diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 43a266cd2..ef38d5707 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -11,9 +11,9 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -94,7 +94,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error return fmt.Errorf("parse prefix: %w", err) } - err = m.firewall.RemoveRoutingRules(routerPair) + err = m.firewall.RemoveNatRule(routerPair) if err != nil { return fmt.Errorf("remove routing rules: %w", err) } @@ -123,7 +123,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { return fmt.Errorf("parse prefix: %w", err) } - err = m.firewall.InsertRoutingRules(routerPair) + err = m.firewall.AddNatRule(routerPair) if err != nil { return fmt.Errorf("insert routing rules: %w", err) } @@ -157,7 +157,7 @@ func (m *defaultServerRouter) cleanUp() { continue } - err = m.firewall.RemoveRoutingRules(routerPair) + err = m.firewall.RemoveNatRule(routerPair) if err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } @@ -173,15 +173,15 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { // TODO: add ipv6 source := getDefaultPrefix(route.Network) - destination := route.Network.Masked().String() + destination := route.Network.Masked() if route.IsDynamic() { - // TODO: add ipv6 - destination = "0.0.0.0/0" + // TODO: add ipv6 additionally + destination = getDefaultPrefix(destination) } return firewall.RouterPair{ - ID: string(route.ID), - Source: source.String(), + ID: route.ID, + Source: source, Destination: destination, Masquerade: route.Masquerade, }, nil diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go index 88cca522a..98c34dbee 100644 --- a/client/internal/routemanager/static/route.go +++ b/client/internal/routemanager/static/route.go @@ -30,7 +30,7 @@ func (r *Route) String() string { } func (r *Route) AddRoute(context.Context) error { - _, err := r.routeRefCounter.Increment(r.route.Network, nil) + _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}) return err } diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go index 13e1229f8..bb620ee68 100644 --- a/client/internal/routemanager/sysctl/sysctl_linux.go +++ b/client/internal/routemanager/sysctl/sysctl_linux.go @@ -13,7 +13,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) const ( diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index ae27b0123..d1cb83bfb 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -5,9 +5,9 @@ import ( "net/netip" "sync" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - "github.com/netbirdio/netbird/iface" ) type Nexthop struct { @@ -15,7 +15,7 @@ type Nexthop struct { Intf *net.Interface } -type ExclusionCounter = refcounter.Counter[any, Nexthop] +type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type SysOps struct { refCounter *ExclusionCounter diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index d76824c10..9258f4a4e 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -16,10 +16,10 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" - "github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -41,7 +41,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn } refCounter := refcounter.New( - func(prefix netip.Prefix, _ any) (Nexthop, error) { + func(prefix netip.Prefix, _ struct{}) (Nexthop, error) { initialNexthop := initialNextHopV4 if prefix.Addr().Is6() { initialNexthop = initialNextHopV6 @@ -317,7 +317,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("convert ip to prefix: %w", err) } - if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil { + if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 94965c119..238225807 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -19,7 +19,7 @@ import ( "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) type dialer interface { diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 4bd4bfff6..e850f4533 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -5,7 +5,6 @@ package ebpf import ( "context" "fmt" - "io" "net" "os" "sync" @@ -81,8 +80,7 @@ func (p *WGEBPFProxy) Listen() error { conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { - cErr := p.Free() - if cErr != nil { + if cErr := p.Free(); cErr != nil { log.Errorf("Failed to close the wgproxy: %s", cErr) } return err @@ -95,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ @@ -122,8 +119,10 @@ func (p *WGEBPFProxy) Free() error { p.ctxCancel() var result *multierror.Error - if err := p.conn.Close(); err != nil { - result = multierror.Append(result, err) + if p.conn != nil { // p.conn will be nil if we have failed to listen + if err := p.conn.Close(); err != nil { + result = multierror.Append(result, err) + } } if err := p.ebpfManager.FreeWGProxy(); err != nil { @@ -136,35 +135,6 @@ func (p *WGEBPFProxy) Free() error { return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { - defer p.removeTurnConn(endpointPort) - - var ( - err error - n int - ) - buf := make([]byte, 1500) - for ctx.Err() == nil { - n, err = remoteConn.Read(buf) - if err != nil { - if ctx.Err() != nil { - return - } - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } - return - } - - if err := p.sendPkg(buf[:n], endpointPort); err != nil { - if ctx.Err() != nil || p.ctx.Err() != nil { - return - } - log.Errorf("failed to write out turn pkg to local conn: %v", err) - } - } -} - // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. func (p *WGEBPFProxy) proxyToRemote() { @@ -279,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return packetConn, nil } -func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { +func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { localhost := net.ParseIP("127.0.0.1") payload := gopacket.Payload(data) diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go index c5639f840..b6a8ac452 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -4,8 +4,13 @@ package ebpf import ( "context" + "errors" "fmt" + "io" "net" + "sync" + + log "github.com/sirupsen/logrus" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -13,20 +18,55 @@ type ProxyWrapper struct { WgeBPFProxy *WGEBPFProxy remoteConn net.Conn - cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread + ctx context.Context + cancel context.CancelFunc + + wgEndpointAddr *net.UDPAddr + + pausedMu sync.Mutex + paused bool + isStarted bool } -func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - ctxConn, cancel := context.WithCancel(ctx) - addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) - +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { + addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { - cancel() - return nil, fmt.Errorf("add turn conn: %w", err) + return fmt.Errorf("add turn conn: %w", err) } - e.remoteConn = remoteConn - e.cancel = cancel - return addr, err + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + p.wgEndpointAddr = addr + return err +} + +func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { + return p.wgEndpointAddr +} + +func (p *ProxyWrapper) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) + } +} + +func (p *ProxyWrapper) Pause() { + if p.remoteConn == nil { + return + } + + log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map @@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error { } return nil } + +func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { + defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + + buf := make([]byte, 1500) + for { + n, err := p.readFromRemote(ctx, buf) + if err != nil { + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) + p.pausedMu.Unlock() + + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to write out turn pkg to local conn: %v", err) + } + } +} + +func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return 0, ctx.Err() + } + if !errors.Is(err, io.EOF) { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + } + return 0, err + } + return n, nil +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index 96fae8dd1..558121cdd 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -7,6 +7,9 @@ import ( // Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) + AddTurnConn(ctx context.Context, turnConn net.Conn) error + EndpointAddr() *net.UDPAddr + Work() + Pause() CloseConn() error } diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go index b09e6be55..b88ff3f83 100644 --- a/client/internal/wgproxy/proxy_test.go +++ b/client/internal/wgproxy/proxy_test.go @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - _, err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go index 83a8725d8..f73500717 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/internal/wgproxy/usp/proxy.go @@ -15,13 +15,17 @@ import ( // WGUserSpaceProxy proxies type WGUserSpaceProxy struct { localWGListenPort int - ctx context.Context - cancel context.CancelFunc remoteConn net.Conn localConn net.Conn + ctx context.Context + cancel context.CancelFunc closeMu sync.Mutex closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool } // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation @@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { return p } -// AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - p.ctx, p.cancel = context.WithCancel(ctx) - - p.remoteConn = remoteConn - - var err error +// AddTurnConn +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { dialer := net.Dialer{} - p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) - return nil, err + return err } - go p.proxyToRemote() - go p.proxyToLocal() + p.ctx, p.cancel = context.WithCancel(ctx) + p.localConn = localConn + p.remoteConn = remoteConn - return p.localConn.LocalAddr(), err + return err +} + +func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { + if p.localConn == nil { + return nil + } + endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) + return endpointUdpAddr +} + +// Work starts the proxy or resumes it if it was paused +func (p *WGUserSpaceProxy) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToRemote(p.ctx) + go p.proxyToLocal(p.ctx) + } +} + +// Pause pauses the proxy from receiving data from the remote peer +func (p *WGUserSpaceProxy) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the localConn @@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote() { +func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to read from wg interface conn: %s", err) @@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() { _, err = p.remoteConn.Write(buf[:n]) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } @@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() { } // proxyToLocal proxies from the Remote peer to local WireGuard -func (p *WGUserSpaceProxy) proxyToLocal() { +// if the proxy is paused it will drain the remote conn and drop the packets +func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) @@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for { n, err := p.remoteConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + _, err = p.localConn.Write(buf[:n]) + p.pausedMu.Unlock() + if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to write to wg interface conn: %s", err) diff --git a/client/server/server_test.go b/client/server/server_test.go index 795060fab..61bdaf660 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } @@ -160,7 +160,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - srv, err := signalServer.NewServer(otel.Meter("")) + srv, err := signalServer.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) proto.RegisterSignalExchangeServer(s, srv) diff --git a/client/testdata/store.json b/client/testdata/store.json deleted file mode 100644 index 8236f2703..000000000 --- a/client/testdata/store.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": {}, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} \ No newline at end of file diff --git a/client/testdata/store.sql b/client/testdata/store.sql new file mode 100644 index 000000000..ed5395486 --- /dev/null +++ b/client/testdata/store.sql @@ -0,0 +1,36 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO installations VALUES(1,''); + +COMMIT; diff --git a/go.mod b/go.mod index 12709e50d..e7e3c17a6 100644 --- a/go.mod +++ b/go.mod @@ -19,8 +19,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.24.0 - golang.org/x/sys v0.21.0 + golang.org/x/crypto v0.28.0 + golang.org/x/sys v0.26.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -38,6 +38,7 @@ require ( github.com/cilium/ebpf v0.15.0 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 + github.com/davecgh/go-spew v1.1.1 github.com/eko/gocache/v3 v3.1.1 github.com/fsnotify/fsnotify v1.7.0 github.com/gliderlabs/ssh v0.3.4 @@ -45,7 +46,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/gopacket v1.1.19 - github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/google/nftables v0.2.0 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 @@ -55,12 +56,12 @@ require ( github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.7 github.com/mattn/go-sqlite3 v1.14.19 - github.com/mdlayher/socket v0.4.1 + github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 + github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible @@ -89,10 +90,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.26.0 + golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.19.0 - golang.org/x/sync v0.7.0 - golang.org/x/term v0.21.0 + golang.org/x/sync v0.8.0 + golang.org/x/term v0.25.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.7 @@ -133,7 +134,6 @@ require ( github.com/containerd/containerd v1.7.16 // indirect github.com/containerd/log v0.1.0 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect @@ -219,7 +219,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/text v0.19.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index 2355f6f0c..e9bc318d6 100644 --- a/go.sum +++ b/go.sum @@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= +github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= +github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= -github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= -github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -521,12 +521,12 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= @@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -871,8 +871,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/iface/tun.go b/iface/tun.go deleted file mode 100644 index 7d0a57ed6..000000000 --- a/iface/tun.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build !android -// +build !android - -package iface - -import ( - "github.com/netbirdio/netbird/iface/bind" -) - -// CustomWindowsGUIDString is a custom GUID string for the interface -var CustomWindowsGUIDString string - -type wgTunDevice interface { - Create() (wgConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) - UpdateAddr(address WGAddress) error - WgAddress() WGAddress - DeviceName() string - Close() error - Wrapper() *DeviceWrapper // todo eliminate this function -} diff --git a/iface/wg_configurer.go b/iface/wg_configurer.go deleted file mode 100644 index dd38ba075..000000000 --- a/iface/wg_configurer.go +++ /dev/null @@ -1,21 +0,0 @@ -package iface - -import ( - "errors" - "net" - "time" - - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -var ErrPeerNotFound = errors.New("peer not found") - -type wgConfigurer interface { - configureInterface(privateKey string, port int) error - updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error - removePeer(peerKey string) error - addAllowedIP(peerKey string, allowedIP string) error - removeAllowedIP(peerKey string, allowedIP string) error - close() - getStats(peerKey string) (WGStats, error) -} diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index c0275536b..2c5c35d53 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -793,6 +793,11 @@ services: volumes: - netbird_caddy_data:/data - ./Caddyfile:/etc/caddy/Caddyfile + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" # UI dashboard dashboard: image: netbirdio/dashboard:latest diff --git a/management/Dockerfile b/management/Dockerfile index cac640bf4..3b2df2623 100644 --- a/management/Dockerfile +++ b/management/Dockerfile @@ -1,5 +1,5 @@ -FROM ubuntu:22.04 +FROM ubuntu:24.04 RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt ENTRYPOINT [ "/go/bin/netbird-mgmt","management"] CMD ["--log-file", "console"] -COPY netbird-mgmt /go/bin/netbird-mgmt \ No newline at end of file +COPY netbird-mgmt /go/bin/netbird-mgmt diff --git a/management/Dockerfile.debug b/management/Dockerfile.debug index f4be366a8..4d9730bd7 100644 --- a/management/Dockerfile.debug +++ b/management/Dockerfile.debug @@ -1,4 +1,4 @@ -FROM ubuntu:22.04 +FROM ubuntu:24.04 RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"] CMD ["--log-file", "console"] diff --git a/management/client/client_test.go b/management/client/client_test.go index a082e354b..100b3fcaa 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -4,7 +4,6 @@ import ( "context" "net" "os" - "path/filepath" "sync" "testing" "time" @@ -47,25 +46,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { level, _ := log.ParseLevel("debug") log.SetLevel(level) - testDir := t.TempDir() - config := &mgmt.Config{} _, err := util.ReadJson("../server/testdata/management.json", config) if err != nil { t.Fatal(err) } - config.Datadir = testDir - err = util.CopyFileContents("../server/testdata/store.json", filepath.Join(testDir, "store.json")) - if err != nil { - t.Fatal(err) - } lis, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 78b1a8d63..719d1a78c 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { loadedConfig := &server.Config{} - _, err := util.ReadJson(mgmtConfigPath, loadedConfig) + _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) if err != nil { return nil, err } diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 48f048c4c..672b2a102 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.21.12 +// protoc v4.23.4 // source: management.proto package proto @@ -21,6 +21,153 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type RuleProtocol int32 + +const ( + RuleProtocol_UNKNOWN RuleProtocol = 0 + RuleProtocol_ALL RuleProtocol = 1 + RuleProtocol_TCP RuleProtocol = 2 + RuleProtocol_UDP RuleProtocol = 3 + RuleProtocol_ICMP RuleProtocol = 4 +) + +// Enum value maps for RuleProtocol. +var ( + RuleProtocol_name = map[int32]string{ + 0: "UNKNOWN", + 1: "ALL", + 2: "TCP", + 3: "UDP", + 4: "ICMP", + } + RuleProtocol_value = map[string]int32{ + "UNKNOWN": 0, + "ALL": 1, + "TCP": 2, + "UDP": 3, + "ICMP": 4, + } +) + +func (x RuleProtocol) Enum() *RuleProtocol { + p := new(RuleProtocol) + *p = x + return p +} + +func (x RuleProtocol) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RuleProtocol) Descriptor() protoreflect.EnumDescriptor { + return file_management_proto_enumTypes[0].Descriptor() +} + +func (RuleProtocol) Type() protoreflect.EnumType { + return &file_management_proto_enumTypes[0] +} + +func (x RuleProtocol) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RuleProtocol.Descriptor instead. +func (RuleProtocol) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{0} +} + +type RuleDirection int32 + +const ( + RuleDirection_IN RuleDirection = 0 + RuleDirection_OUT RuleDirection = 1 +) + +// Enum value maps for RuleDirection. +var ( + RuleDirection_name = map[int32]string{ + 0: "IN", + 1: "OUT", + } + RuleDirection_value = map[string]int32{ + "IN": 0, + "OUT": 1, + } +) + +func (x RuleDirection) Enum() *RuleDirection { + p := new(RuleDirection) + *p = x + return p +} + +func (x RuleDirection) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RuleDirection) Descriptor() protoreflect.EnumDescriptor { + return file_management_proto_enumTypes[1].Descriptor() +} + +func (RuleDirection) Type() protoreflect.EnumType { + return &file_management_proto_enumTypes[1] +} + +func (x RuleDirection) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RuleDirection.Descriptor instead. +func (RuleDirection) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{1} +} + +type RuleAction int32 + +const ( + RuleAction_ACCEPT RuleAction = 0 + RuleAction_DROP RuleAction = 1 +) + +// Enum value maps for RuleAction. +var ( + RuleAction_name = map[int32]string{ + 0: "ACCEPT", + 1: "DROP", + } + RuleAction_value = map[string]int32{ + "ACCEPT": 0, + "DROP": 1, + } +) + +func (x RuleAction) Enum() *RuleAction { + p := new(RuleAction) + *p = x + return p +} + +func (x RuleAction) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RuleAction) Descriptor() protoreflect.EnumDescriptor { + return file_management_proto_enumTypes[2].Descriptor() +} + +func (RuleAction) Type() protoreflect.EnumType { + return &file_management_proto_enumTypes[2] +} + +func (x RuleAction) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RuleAction.Descriptor instead. +func (RuleAction) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{2} +} + type HostConfig_Protocol int32 const ( @@ -60,11 +207,11 @@ func (x HostConfig_Protocol) String() string { } func (HostConfig_Protocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[0].Descriptor() + return file_management_proto_enumTypes[3].Descriptor() } func (HostConfig_Protocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[0] + return &file_management_proto_enumTypes[3] } func (x HostConfig_Protocol) Number() protoreflect.EnumNumber { @@ -103,11 +250,11 @@ func (x DeviceAuthorizationFlowProvider) String() string { } func (DeviceAuthorizationFlowProvider) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[1].Descriptor() + return file_management_proto_enumTypes[4].Descriptor() } func (DeviceAuthorizationFlowProvider) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[1] + return &file_management_proto_enumTypes[4] } func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { @@ -119,153 +266,6 @@ func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { return file_management_proto_rawDescGZIP(), []int{21, 0} } -type FirewallRuleDirection int32 - -const ( - FirewallRule_IN FirewallRuleDirection = 0 - FirewallRule_OUT FirewallRuleDirection = 1 -) - -// Enum value maps for FirewallRuleDirection. -var ( - FirewallRuleDirection_name = map[int32]string{ - 0: "IN", - 1: "OUT", - } - FirewallRuleDirection_value = map[string]int32{ - "IN": 0, - "OUT": 1, - } -) - -func (x FirewallRuleDirection) Enum() *FirewallRuleDirection { - p := new(FirewallRuleDirection) - *p = x - return p -} - -func (x FirewallRuleDirection) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (FirewallRuleDirection) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[2].Descriptor() -} - -func (FirewallRuleDirection) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[2] -} - -func (x FirewallRuleDirection) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use FirewallRuleDirection.Descriptor instead. -func (FirewallRuleDirection) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31, 0} -} - -type FirewallRuleAction int32 - -const ( - FirewallRule_ACCEPT FirewallRuleAction = 0 - FirewallRule_DROP FirewallRuleAction = 1 -) - -// Enum value maps for FirewallRuleAction. -var ( - FirewallRuleAction_name = map[int32]string{ - 0: "ACCEPT", - 1: "DROP", - } - FirewallRuleAction_value = map[string]int32{ - "ACCEPT": 0, - "DROP": 1, - } -) - -func (x FirewallRuleAction) Enum() *FirewallRuleAction { - p := new(FirewallRuleAction) - *p = x - return p -} - -func (x FirewallRuleAction) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (FirewallRuleAction) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[3].Descriptor() -} - -func (FirewallRuleAction) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[3] -} - -func (x FirewallRuleAction) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use FirewallRuleAction.Descriptor instead. -func (FirewallRuleAction) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31, 1} -} - -type FirewallRuleProtocol int32 - -const ( - FirewallRule_UNKNOWN FirewallRuleProtocol = 0 - FirewallRule_ALL FirewallRuleProtocol = 1 - FirewallRule_TCP FirewallRuleProtocol = 2 - FirewallRule_UDP FirewallRuleProtocol = 3 - FirewallRule_ICMP FirewallRuleProtocol = 4 -) - -// Enum value maps for FirewallRuleProtocol. -var ( - FirewallRuleProtocol_name = map[int32]string{ - 0: "UNKNOWN", - 1: "ALL", - 2: "TCP", - 3: "UDP", - 4: "ICMP", - } - FirewallRuleProtocol_value = map[string]int32{ - "UNKNOWN": 0, - "ALL": 1, - "TCP": 2, - "UDP": 3, - "ICMP": 4, - } -) - -func (x FirewallRuleProtocol) Enum() *FirewallRuleProtocol { - p := new(FirewallRuleProtocol) - *p = x - return p -} - -func (x FirewallRuleProtocol) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (FirewallRuleProtocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[4].Descriptor() -} - -func (FirewallRuleProtocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[4] -} - -func (x FirewallRuleProtocol) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use FirewallRuleProtocol.Descriptor instead. -func (FirewallRuleProtocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31, 2} -} - type EncryptedMessage struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1482,6 +1482,10 @@ type NetworkMap struct { FirewallRules []*FirewallRule `protobuf:"bytes,8,rep,name=FirewallRules,proto3" json:"FirewallRules,omitempty"` // firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality. FirewallRulesIsEmpty bool `protobuf:"varint,9,opt,name=firewallRulesIsEmpty,proto3" json:"firewallRulesIsEmpty,omitempty"` + // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer + RoutesFirewallRules []*RouteFirewallRule `protobuf:"bytes,10,rep,name=routesFirewallRules,proto3" json:"routesFirewallRules,omitempty"` + // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. + RoutesFirewallRulesIsEmpty bool `protobuf:"varint,11,opt,name=routesFirewallRulesIsEmpty,proto3" json:"routesFirewallRulesIsEmpty,omitempty"` } func (x *NetworkMap) Reset() { @@ -1579,6 +1583,20 @@ func (x *NetworkMap) GetFirewallRulesIsEmpty() bool { return false } +func (x *NetworkMap) GetRoutesFirewallRules() []*RouteFirewallRule { + if x != nil { + return x.RoutesFirewallRules + } + return nil +} + +func (x *NetworkMap) GetRoutesFirewallRulesIsEmpty() bool { + if x != nil { + return x.RoutesFirewallRulesIsEmpty + } + return false +} + // RemotePeerConfig represents a configuration of a remote peer. // The properties are used to configure WireGuard Peers sections type RemotePeerConfig struct { @@ -2487,11 +2505,11 @@ type FirewallRule struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"` - Direction FirewallRuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.FirewallRuleDirection" json:"Direction,omitempty"` - Action FirewallRuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.FirewallRuleAction" json:"Action,omitempty"` - Protocol FirewallRuleProtocol `protobuf:"varint,4,opt,name=Protocol,proto3,enum=management.FirewallRuleProtocol" json:"Protocol,omitempty"` - Port string `protobuf:"bytes,5,opt,name=Port,proto3" json:"Port,omitempty"` + PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"` + Direction RuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.RuleDirection" json:"Direction,omitempty"` + Action RuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.RuleAction" json:"Action,omitempty"` + Protocol RuleProtocol `protobuf:"varint,4,opt,name=Protocol,proto3,enum=management.RuleProtocol" json:"Protocol,omitempty"` + Port string `protobuf:"bytes,5,opt,name=Port,proto3" json:"Port,omitempty"` } func (x *FirewallRule) Reset() { @@ -2533,25 +2551,25 @@ func (x *FirewallRule) GetPeerIP() string { return "" } -func (x *FirewallRule) GetDirection() FirewallRuleDirection { +func (x *FirewallRule) GetDirection() RuleDirection { if x != nil { return x.Direction } - return FirewallRule_IN + return RuleDirection_IN } -func (x *FirewallRule) GetAction() FirewallRuleAction { +func (x *FirewallRule) GetAction() RuleAction { if x != nil { return x.Action } - return FirewallRule_ACCEPT + return RuleAction_ACCEPT } -func (x *FirewallRule) GetProtocol() FirewallRuleProtocol { +func (x *FirewallRule) GetProtocol() RuleProtocol { if x != nil { return x.Protocol } - return FirewallRule_UNKNOWN + return RuleProtocol_UNKNOWN } func (x *FirewallRule) GetPort() string { @@ -2663,6 +2681,236 @@ func (x *Checks) GetFiles() []string { return nil } +type PortInfo struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to PortSelection: + // + // *PortInfo_Port + // *PortInfo_Range_ + PortSelection isPortInfo_PortSelection `protobuf_oneof:"portSelection"` +} + +func (x *PortInfo) Reset() { + *x = PortInfo{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[34] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PortInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PortInfo) ProtoMessage() {} + +func (x *PortInfo) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[34] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. +func (*PortInfo) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{34} +} + +func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { + if m != nil { + return m.PortSelection + } + return nil +} + +func (x *PortInfo) GetPort() uint32 { + if x, ok := x.GetPortSelection().(*PortInfo_Port); ok { + return x.Port + } + return 0 +} + +func (x *PortInfo) GetRange() *PortInfo_Range { + if x, ok := x.GetPortSelection().(*PortInfo_Range_); ok { + return x.Range + } + return nil +} + +type isPortInfo_PortSelection interface { + isPortInfo_PortSelection() +} + +type PortInfo_Port struct { + Port uint32 `protobuf:"varint,1,opt,name=port,proto3,oneof"` +} + +type PortInfo_Range_ struct { + Range *PortInfo_Range `protobuf:"bytes,2,opt,name=range,proto3,oneof"` +} + +func (*PortInfo_Port) isPortInfo_PortSelection() {} + +func (*PortInfo_Range_) isPortInfo_PortSelection() {} + +// RouteFirewallRule signifies a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // sourceRanges IP ranges of the routing peers. + SourceRanges []string `protobuf:"bytes,1,rep,name=sourceRanges,proto3" json:"sourceRanges,omitempty"` + // Action to be taken by the firewall when the rule is applicable. + Action RuleAction `protobuf:"varint,2,opt,name=action,proto3,enum=management.RuleAction" json:"action,omitempty"` + // Network prefix for the routed network. + Destination string `protobuf:"bytes,3,opt,name=destination,proto3" json:"destination,omitempty"` + // Protocol of the routed network. + Protocol RuleProtocol `protobuf:"varint,4,opt,name=protocol,proto3,enum=management.RuleProtocol" json:"protocol,omitempty"` + // Details about the port. + PortInfo *PortInfo `protobuf:"bytes,5,opt,name=portInfo,proto3" json:"portInfo,omitempty"` + // IsDynamic indicates if the route is a DNS route. + IsDynamic bool `protobuf:"varint,6,opt,name=isDynamic,proto3" json:"isDynamic,omitempty"` +} + +func (x *RouteFirewallRule) Reset() { + *x = RouteFirewallRule{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[35] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RouteFirewallRule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RouteFirewallRule) ProtoMessage() {} + +func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[35] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. +func (*RouteFirewallRule) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{35} +} + +func (x *RouteFirewallRule) GetSourceRanges() []string { + if x != nil { + return x.SourceRanges + } + return nil +} + +func (x *RouteFirewallRule) GetAction() RuleAction { + if x != nil { + return x.Action + } + return RuleAction_ACCEPT +} + +func (x *RouteFirewallRule) GetDestination() string { + if x != nil { + return x.Destination + } + return "" +} + +func (x *RouteFirewallRule) GetProtocol() RuleProtocol { + if x != nil { + return x.Protocol + } + return RuleProtocol_UNKNOWN +} + +func (x *RouteFirewallRule) GetPortInfo() *PortInfo { + if x != nil { + return x.PortInfo + } + return nil +} + +func (x *RouteFirewallRule) GetIsDynamic() bool { + if x != nil { + return x.IsDynamic + } + return false +} + +type PortInfo_Range struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` + End uint32 `protobuf:"varint,2,opt,name=end,proto3" json:"end,omitempty"` +} + +func (x *PortInfo_Range) Reset() { + *x = PortInfo_Range{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[36] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PortInfo_Range) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PortInfo_Range) ProtoMessage() {} + +func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[36] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. +func (*PortInfo_Range) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{34, 0} +} + +func (x *PortInfo_Range) GetStart() uint32 { + if x != nil { + return x.Start + } + return 0 +} + +func (x *PortInfo_Range) GetEnd() uint32 { + if x != nil { + return x.End + } + return 0 +} + var File_management_proto protoreflect.FileDescriptor var file_management_proto_rawDesc = []byte{ @@ -2835,7 +3083,7 @@ var file_management_proto_rawDesc = []byte{ 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xe2, 0x03, 0x0a, 0x0a, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xf3, 0x04, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, @@ -2866,184 +3114,219 @@ var file_management_proto_rawDesc = []byte{ 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, - 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, - 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, - 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, - 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, - 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, - 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, - 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, - 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, - 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, - 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, - 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, - 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, - 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, - 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, - 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, - 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, - 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, - 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, - 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, - 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, - 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, - 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, - 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, - 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, - 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xf0, 0x02, 0x0a, 0x0c, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, - 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, - 0x72, 0x49, 0x50, 0x12, 0x40, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x37, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, - 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x3d, - 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, - 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, - 0x74, 0x22, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, - 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x22, - 0x1e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, - 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x22, - 0x3c, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, - 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, - 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, - 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x22, 0x38, 0x0a, - 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, - 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, - 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, - 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, - 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, - 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, + 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, + 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, + 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, + 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, + 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, + 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, + 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, + 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, + 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, + 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, + 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, + 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, + 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, + 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, + 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, + 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, + 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, + 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, + 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, + 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, + 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, + 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, + 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, + 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, + 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, + 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, + 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, + 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, + 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, + 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, + 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, + 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, + 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, + 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, + 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, + 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, + 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, + 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, + 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, + 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, + 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, + 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, + 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, + 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, + 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, + 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, + 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, + 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, + 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, + 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, + 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, + 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, + 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x8f, 0x02, 0x0a, 0x11, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, + 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, + 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, + 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, + 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x2a, 0x40, 0x0a, 0x0c, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, + 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, + 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, + 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x2a, 0x20, + 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, + 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, + 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, + 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, - 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, + 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, + 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3059,13 +3342,13 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 34) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 37) var file_management_proto_goTypes = []interface{}{ - (HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol - (DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider - (FirewallRuleDirection)(0), // 2: management.FirewallRule.direction - (FirewallRuleAction)(0), // 3: management.FirewallRule.action - (FirewallRuleProtocol)(0), // 4: management.FirewallRule.protocol + (RuleProtocol)(0), // 0: management.RuleProtocol + (RuleDirection)(0), // 1: management.RuleDirection + (RuleAction)(0), // 2: management.RuleAction + (HostConfig_Protocol)(0), // 3: management.HostConfig.Protocol + (DeviceAuthorizationFlowProvider)(0), // 4: management.DeviceAuthorizationFlow.provider (*EncryptedMessage)(nil), // 5: management.EncryptedMessage (*SyncRequest)(nil), // 6: management.SyncRequest (*SyncResponse)(nil), // 7: management.SyncResponse @@ -3100,7 +3383,10 @@ var file_management_proto_goTypes = []interface{}{ (*FirewallRule)(nil), // 36: management.FirewallRule (*NetworkAddress)(nil), // 37: management.NetworkAddress (*Checks)(nil), // 38: management.Checks - (*timestamppb.Timestamp)(nil), // 39: google.protobuf.Timestamp + (*PortInfo)(nil), // 39: management.PortInfo + (*RouteFirewallRule)(nil), // 40: management.RouteFirewallRule + (*PortInfo_Range)(nil), // 41: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 42: google.protobuf.Timestamp } var file_management_proto_depIdxs = []int32{ 13, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta @@ -3118,12 +3404,12 @@ var file_management_proto_depIdxs = []int32{ 17, // 12: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig 21, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig 38, // 14: management.LoginResponse.Checks:type_name -> management.Checks - 39, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 42, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 18, // 16: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig 20, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig 18, // 18: management.WiretrusteeConfig.signal:type_name -> management.HostConfig 19, // 19: management.WiretrusteeConfig.relay:type_name -> management.RelayConfig - 0, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 3, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol 18, // 21: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig 24, // 22: management.PeerConfig.sshConfig:type_name -> management.SSHConfig 21, // 23: management.NetworkMap.peerConfig:type_name -> management.PeerConfig @@ -3132,36 +3418,41 @@ var file_management_proto_depIdxs = []int32{ 31, // 26: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig 23, // 27: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig 36, // 28: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 24, // 29: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 1, // 30: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 29, // 31: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 29, // 32: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 34, // 33: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 32, // 34: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 33, // 35: management.CustomZone.Records:type_name -> management.SimpleRecord - 35, // 36: management.NameServerGroup.NameServers:type_name -> management.NameServer - 2, // 37: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction - 3, // 38: management.FirewallRule.Action:type_name -> management.FirewallRule.action - 4, // 39: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol - 5, // 40: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 41: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 16, // 42: management.ManagementService.GetServerKey:input_type -> management.Empty - 16, // 43: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 44: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 45: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 46: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 47: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 48: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 15, // 49: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 16, // 50: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 51: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 52: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 16, // 53: management.ManagementService.SyncMeta:output_type -> management.Empty - 47, // [47:54] is the sub-list for method output_type - 40, // [40:47] is the sub-list for method input_type - 40, // [40:40] is the sub-list for extension type_name - 40, // [40:40] is the sub-list for extension extendee - 0, // [0:40] is the sub-list for field type_name + 40, // 29: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 24, // 30: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 4, // 31: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 29, // 32: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 29, // 33: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 34, // 34: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 32, // 35: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 33, // 36: management.CustomZone.Records:type_name -> management.SimpleRecord + 35, // 37: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 38: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 39: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 40: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 41, // 41: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 42: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 43: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 39, // 44: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 5, // 45: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 46: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 16, // 47: management.ManagementService.GetServerKey:input_type -> management.Empty + 16, // 48: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 49: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 50: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 51: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 52: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 53: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 15, // 54: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 16, // 55: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 56: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 57: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 16, // 58: management.ManagementService.SyncMeta:output_type -> management.Empty + 52, // [52:59] is the sub-list for method output_type + 45, // [45:52] is the sub-list for method input_type + 45, // [45:45] is the sub-list for extension type_name + 45, // [45:45] is the sub-list for extension extendee + 0, // [0:45] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -3578,6 +3869,46 @@ func file_management_proto_init() { return nil } } + file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RouteFirewallRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo_Range); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_management_proto_msgTypes[34].OneofWrappers = []interface{}{ + (*PortInfo_Port)(nil), + (*PortInfo_Range_)(nil), } type x struct{} out := protoimpl.TypeBuilder{ @@ -3585,7 +3916,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 34, + NumMessages: 37, NumExtensions: 0, NumServices: 1, }, diff --git a/management/proto/management.proto b/management/proto/management.proto index c5646820f..fe6a828b1 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -254,6 +254,12 @@ message NetworkMap { // firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality. bool firewallRulesIsEmpty = 9; + + // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer + repeated RouteFirewallRule routesFirewallRules = 10; + + // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. + bool routesFirewallRulesIsEmpty = 11; } // RemotePeerConfig represents a configuration of a remote peer. @@ -384,29 +390,32 @@ message NameServer { int64 Port = 3; } +enum RuleProtocol { + UNKNOWN = 0; + ALL = 1; + TCP = 2; + UDP = 3; + ICMP = 4; +} + +enum RuleDirection { + IN = 0; + OUT = 1; +} + +enum RuleAction { + ACCEPT = 0; + DROP = 1; +} + + // FirewallRule represents a firewall rule message FirewallRule { string PeerIP = 1; - direction Direction = 2; - action Action = 3; - protocol Protocol = 4; + RuleDirection Direction = 2; + RuleAction Action = 3; + RuleProtocol Protocol = 4; string Port = 5; - - enum direction { - IN = 0; - OUT = 1; - } - enum action { - ACCEPT = 0; - DROP = 1; - } - enum protocol { - UNKNOWN = 0; - ALL = 1; - TCP = 2; - UDP = 3; - ICMP = 4; - } } message NetworkAddress { @@ -415,5 +424,40 @@ message NetworkAddress { } message Checks { - repeated string Files= 1; + repeated string Files = 1; } + + +message PortInfo { + oneof portSelection { + uint32 port = 1; + Range range = 2; + } + + message Range { + uint32 start = 1; + uint32 end = 2; + } +} + +// RouteFirewallRule signifies a firewall rule applicable for a routed network. +message RouteFirewallRule { + // sourceRanges IP ranges of the routing peers. + repeated string sourceRanges = 1; + + // Action to be taken by the firewall when the rule is applicable. + RuleAction action = 2; + + // Network prefix for the routed network. + string destination = 3; + + // Protocol of the routed network. + RuleProtocol protocol = 4; + + // Details about the port. + PortInfo portInfo = 5; + + // IsDynamic indicates if the route is a DNS route. + bool isDynamic = 6; +} + diff --git a/management/server/account.go b/management/server/account.go index 0022b14d8..7c84ad1ca 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" + "errors" "fmt" "hash/crc32" "math/rand" @@ -20,6 +21,11 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -36,10 +42,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" ) const ( @@ -49,6 +51,9 @@ const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour + DefaultPeerInactivityExpiration = 10 * time.Minute + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -76,7 +81,8 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) - GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) + AccountExists(ctx context.Context, accountID string) (bool, error) + GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) @@ -113,7 +119,7 @@ type AccountManager interface { DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) - CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) @@ -128,14 +134,14 @@ type AccountManager interface { GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager @@ -176,6 +182,8 @@ type DefaultAccountManager struct { dnsDomain string peerLoginExpiry Scheduler + peerInactivityExpiry Scheduler + // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool @@ -193,6 +201,13 @@ type Settings struct { // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration + PeerInactivityExpirationEnabled bool + + // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. + // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. + PeerInactivityExpiration time.Duration + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements RegularUsersViewBlocked bool @@ -223,6 +238,9 @@ func (s *Settings) Copy() *Settings { GroupsPropagationEnabled: s.GroupsPropagationEnabled, JWTAllowGroups: s.JWTAllowGroups, RegularUsersViewBlocked: s.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: s.PeerInactivityExpiration, } if s.Extra != nil { settings.Extra = s.Extra.Copy() @@ -460,6 +478,7 @@ func (a *Account) GetPeerNetworkMap( } routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) + routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) dnsUpdate := nbdns.Config{ @@ -477,12 +496,13 @@ func (a *Account) GetPeerNetworkMap( } nm := &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, } if metrics != nil { @@ -602,6 +622,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { return peers } +// GetInactivePeers returns peers that have been expired by inactivity +func (a *Account) GetInactivePeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, inactivePeer := range a.GetPeersWithInactivity() { + inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + return peers +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithInactivity() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + // GetPeers returns a list of all Account peers func (a *Account) GetPeers() []*nbpeer.Peer { var peers []*nbpeer.Peer @@ -841,55 +915,54 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer { return a.Peers[peerID] } -// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. -// Returns true if there are changes in the JWT group membership. -func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { - user, ok := a.Users[userID] - if !ok { - return false - } - +// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. +// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, +// newly groups to create and an error if any occurred. +func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { existedGroupsByName := make(map[string]*nbgroup.Group) - for _, group := range a.Groups { + for _, group := range groups { existedGroupsByName[group.Name] = group } - newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) - groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) + newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) + + groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { - return false + return false, nil, nil, nil } + newGroupsToCreate := make([]*nbgroup.Group, 0) + var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { group = &nbgroup.Group{ - ID: xid.New().String(), - Name: name, - Issued: nbgroup.GroupIssuedJWT, + ID: xid.New().String(), + AccountID: user.AccountID, + Name: name, + Issued: nbgroup.GroupIssuedJWT, } - a.Groups[group.ID] = group + newGroupsToCreate = append(newGroupsToCreate, group) } if group.Issued == nbgroup.GroupIssuedJWT { - newAutoGroups = append(newAutoGroups, group.ID) + newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } } for name, id := range jwtGroupsMap { if !slices.Contains(groupsToRemove, name) { - newAutoGroups = append(newAutoGroups, id) + newUserAutoGroups = append(newUserAutoGroups, id) continue } modified = true } - user.AutoGroups = newAutoGroups - return modified + return modified, newUserAutoGroups, newGroupsToCreate, nil } // UserGroupsAddToPeers adds groups to all peers of user @@ -969,6 +1042,7 @@ func BuildManager( dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), + peerInactivityExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, metrics: metrics, @@ -1048,16 +1122,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") - } - +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -1067,57 +1132,78 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - oldSettings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } - if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil { + user, err := account.FindUser(userID) + if err != nil { return nil, err } - if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil { - return nil, fmt.Errorf("failed updating account settings: %w", err) + if !user.HasAdminPower() { + return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") } + err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) + if err != nil { + return nil, err + } + + oldSettings := account.Settings if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { event = activity.AccountPeerLoginExpirationDisabled am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) if err != nil { - return nil, fmt.Errorf("error getting account: %w", err) + return nil, err } - am.updateAccountPeers(ctx, account) - return newSettings, nil + updatedAccount := account.UpdateSettings(newSettings) + + err = am.Store.SaveAccount(ctx, account) + if err != nil { + return nil, err + } + + return updatedAccount, nil } -// validateExtraSettings validates the extra settings of the account. -func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, newSettings, oldSettings *Settings, userID, accountID string) error { - peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) - if err != nil { - return err +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { + if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { + event := activity.AccountPeerInactivityExpirationEnabled + if !newSettings.PeerInactivityExpirationEnabled { + event = activity.AccountPeerInactivityExpirationDisabled + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + } else { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } - peerMap := make(map[string]*nbpeer.Peer, len(peers)) - for _, peer := range peers { - peerMap[peer.ID] = peer + if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) + am.checkAndSchedulePeerInactivityExpiration(ctx, account) } - return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peerMap, userID, accountID) + return nil } func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { @@ -1125,7 +1211,6 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - // TODO: call direct on the store to get expired peers account, err := am.Store.GetAccount(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID) @@ -1140,7 +1225,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) - if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id) return account.GetNextPeerExpiration() } @@ -1149,10 +1234,47 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { - am.peerLoginExpiry.Cancel(ctx, []string{accountID}) - if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { - go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { + am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) + if nextRun, ok := account.GetNextPeerExpiration(); ok { + go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) + } +} + +// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found +func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { + return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + log.Errorf("failed getting account %s expiring peers", account.Id) + return account.GetNextInactivePeerExpiration() + } + + expiredPeers := account.GetInactivePeers() + var peerIDs []string + for _, peer := range expiredPeers { + peerIDs = append(peerIDs, peer.ID) + } + + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + return account.GetNextInactivePeerExpiration() + } + + return account.GetNextInactivePeerExpiration() + } +} + +// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { + am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) + if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { + go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) } } @@ -1274,37 +1396,36 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. -// If an accountID is provided, it checks if the account exists and returns it. -// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// AccountExists checks if an account exists. +func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) +} + +// GetAccountIDByUserID retrieves the account ID based on the userID provided. +// If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. // Returns the account ID or an error if none is found or created. -func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { - if accountID != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) - } - return accountID, nil +func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) { + if userID == "" { + return "", status.Errorf(status.NotFound, "no valid userID provided") } - if userID != "" { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) - if err != nil { - return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) - } + accountID, err := am.Store.GetAccountIDByUserID(userID) + if err != nil { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + if err != nil { + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { - return "", err + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { + return "", err + } + return account.Id, nil } - - return account.Id, nil + return "", err } - - return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") + return accountID, nil } func isNil(i idp.Manager) bool { @@ -1314,9 +1435,20 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { + accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + if err != nil { + return err + } + cachedAccount := &Account{ + Id: accountID, + Users: make(map[string]*User), + } + for _, user := range accountUsers { + cachedAccount.Users[user.Id] = user + } // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(ctx, userID, accountID) + user, err := am.lookupUserInCache(ctx, userID, cachedAccount) if err != nil { return err } @@ -1392,15 +1524,10 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil -func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { - accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - - users := make(map[string]userLoggedInOnce, len(accountUsers)) +func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) { + users := make(map[string]userLoggedInOnce, len(account.Users)) // ignore service users and users provisioned by integrations than are never logged in - for _, user := range accountUsers { + for _, user := range account.Users { if user.IsServiceUser { continue } @@ -1409,9 +1536,8 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s } users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) } - - log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID) - userData, err := am.lookupCache(ctx, users, accountID) + log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id) + userData, err := am.lookupCache(ctx, users, account.Id) if err != nil { return nil, err } @@ -1424,13 +1550,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := account.FindUser(userID) if err != nil { - log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) + log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id) return nil, err } - key := user.IntegrationReference.CacheKey(accountID, userID) + key := user.IntegrationReference.CacheKey(account.Id, userID) ud, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err) @@ -1563,49 +1689,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration())) } -// updateAccountDomainAttributes updates the account domain attributes and then, saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims, +// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { - - if claims.Domain != "" { - account.IsDomainPrimaryAccount = primaryDomain - - lowerDomain := strings.ToLower(claims.Domain) - userObj := account.Users[claims.UserId] - if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { - account.Domain = lowerDomain - } - // prevent updating category for different domain until admin logs in - if account.Domain == lowerDomain { - account.DomainCategory = claims.DomainCategory - } - } else { + if claims.Domain == "" { log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + return nil } - err := am.Store.SaveAccount(ctx, account) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlockAccount() + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) if err != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err } - return nil + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return nil + } + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting user: %v", err) + return err + } + + newDomain := accountDomain + newCategoty := domainCategory + + lowerDomain := strings.ToLower(claims.Domain) + if accountDomain != lowerDomain && user.HasAdminPower() { + newDomain = lowerDomain + } + + if accountDomain == lowerDomain { + newCategoty = claims.DomainCategory + } + + return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. +// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, +// we compare the account's ID with the domain account ID, and if they don't match, we set the account as +// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain +// was previously unclassified or classified as public so N users that logged int that time, has they own account +// and peers that shouldn't be lost. func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, - existingAcc *Account, - primaryDomain bool, + userAccountID string, + domainAccountID string, claims jwtclaims.AuthorizationClaims, ) error { - // TODO: remove account as parameter and pass accountID string - err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain) + primaryDomain := domainAccountID == "" || userAccountID == domainAccountID + err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc.Id) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID) if err != nil { return err } @@ -1613,44 +1759,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount( return nil } -// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, +// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } - var ( - account *Account - err error - ) + lowerDomain := strings.ToLower(claims.Domain) - // if domain already has a primary account, add regular user - if domainAcc != nil { - account = domainAcc - account.Users[claims.UserId] = NewRegularUser(claims.UserId) - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } else { - account, err = am.newAccount(ctx, claims.UserId, lowerDomain) - if err != nil { - return nil, err - } - err = am.updateAccountDomainAttributes(ctx, account, claims, true) - if err != nil { - return nil, err - } - } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account.Id) + newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) if err != nil { - return nil, err + return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) + newAccount.Domain = lowerDomain + newAccount.DomainCategory = claims.DomainCategory + newAccount.IsDomainPrimaryAccount = true - return account, nil + err = am.Store.SaveAccount(ctx, newAccount) + if err != nil { + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) + if err != nil { + return "", err + } + + am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + + return newAccount.Id, nil +} + +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) + defer unlockAccount() + + usersMap := make(map[string]*User) + usersMap[claims.UserId] = NewRegularUser(claims.UserId) + err := am.Store.SaveUsers(domainAccountID, usersMap) + if err != nil { + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID) + if err != nil { + return "", err + } + + am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil) + + return domainAccountID, nil } // redeemInvite checks whether user has been invited and redeems the invite @@ -1661,12 +1821,12 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str return nil } - _, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - user, err := am.lookupUserInCache(ctx, userID, accountID) + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err } @@ -1676,17 +1836,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str } if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite { - log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID) + log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id) // User has already logged in, meaning that IdP should have set wt_pending_invite to false. // Our job is to just reload cache. go func() { - _, err = am.refreshCache(ctx, accountID) + _, err = am.refreshCache(ctx, account.Id) if err != nil { - log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID) + log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id) return } - log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID) - am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil) + log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id) + am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil) }() } @@ -1695,18 +1855,33 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { + user, err := am.Store.GetUserByTokenID(ctx, tokenID) if err != nil { return err } - pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, user.Id) + account, err := am.Store.GetAccountByUser(ctx, user.Id) if err != nil { return err } + + unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) + defer unlock() + + account, err = am.Store.GetAccountByUser(ctx, user.Id) + if err != nil { + return err + } + + pat, ok := account.Users[user.Id].PATs[tokenID] + if !ok { + return fmt.Errorf("token not found") + } + pat.LastUsed = time.Now().UTC() - return am.Store.SavePAT(ctx, LockingStrengthUpdate, pat) + return am.Store.SaveAccount(ctx, account) } // GetAccount returns an account associated with this account ID. @@ -1769,7 +1944,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s return nil, err } - if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + if user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } @@ -1779,7 +1954,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s // GetAccountIDFromToken returns an account ID associated with this token. func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return "", "", fmt.Errorf("user ID is empty") + return "", "", errors.New(emptyUserID) } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1800,6 +1975,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } + if user.AccountID != accountID { + return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID) + } + if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { @@ -1807,7 +1986,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai } } - if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { + if err = am.syncJWTGroups(ctx, accountID, claims); err != nil { return "", "", err } @@ -1816,7 +1995,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. -func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err @@ -1827,83 +2006,151 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set") return nil } - // TODO: Remove GetAccount after refactoring account peer's update - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer func() { + if unlockPeer != nil { + unlockPeer() + } + }() - // Update the account if group membership changes - if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) - - if settings.GroupsPropagationEnabled { - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) - account.Network.IncSerial() + var addNewGroups []string + var removeOldGroups []string + var hasChanges bool + var user *User + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user: %w", err) } - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) + } + + changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames) + if err != nil { + return fmt.Errorf("error getting JWT groups changes: %w", err) + } + + hasChanges = changed + // skip update if no changes + if !changed { return nil } + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + addNewGroups = difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + + user.AutoGroups = updatedAutoGroups + if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + return fmt.Errorf("error saving user: %w", err) + } + // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) - } + groups, err = transaction.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) + } - for _, g := range addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user peers: %w", err) + } + + updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) + if err != nil { + return fmt.Errorf("error modifying user peers in groups: %w", err) + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return fmt.Errorf("error incrementing network serial: %w", err) } } + unlockPeer() + unlockPeer = nil - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) + return nil + }) + if err != nil { + return err + } + + if !hasChanges { + return nil + } + + for _, g := range addNewGroups { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) } } + for _, g := range removeOldGroups { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) + } + } + + if settings.GroupsPropagationEnabled { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + return nil } // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. +// if domain is not private or domain is invalid, it will return the account ID by user ID. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // // Use cases: // -// New user + New account + New domain -> create account, user role = admin (if private domain, index domain) +// New user + New account + New domain -> create account, user role = owner (if private domain, index domain) // -// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin) +// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin) // -// New user + New account + Existing Public Domain -> create account, user role = admin +// New user + New account + Existing Public Domain -> create account, user role = owner // // Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain) // @@ -1913,88 +2160,123 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) + if claims.UserId == "" { - return "", fmt.Errorf("user ID is empty") + return "", errors.New(emptyUserID) } - // if Account ID is part of the claims - // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) - } else if claims.AccountId != "" { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err != nil { - return "", err - } - - if userAccountID != claims.AccountId { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) - } - - domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { - return "", err - } - - if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { - return userAccountID, nil - } + return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } - start := time.Now() - unlock := am.Store.AcquireGlobalLock(ctx) - defer unlock() - log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) + if claims.AccountId != "" { + return am.handlePrivateAccountWithIDFromClaim(ctx, claims) + } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain) + if cancel != nil { + defer cancel() + } if err != nil { - // if NotFound we are good to continue, otherwise return error - e, ok := status.FromError(err) - if !ok || e.Type() != status.NotFound { - return "", err - } + return "", err } userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) - defer unlockAccount() - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) - if err != nil { - return "", err - } - // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, - // we compare the account's ID with the domain account ID, and if they don't match, we set the account as - // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain - // was previously unclassified or classified as public so N users that logged int that time, has they own account - // and peers that shouldn't be lost. - primaryDomain := domainAccountID == "" || account.Id == domainAccountID - if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { - return "", err - } - - return account.Id, nil - } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - var domainAccount *Account - if domainAccountID != "" { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) - defer unlockAccount() - domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) - if err != nil { - return "", err - } - } - - account, err := am.handleNewUserAccount(ctx, domainAccount, claims) - if err != nil { - return "", err - } - return account.Id, nil - } else { - // other error + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } + + if userAccountID != "" { + if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil { + return "", err + } + + return userAccountID, nil + } + + if domainAccountID != "" { + return am.addNewUserToDomainAccount(ctx, domainAccountID, claims) + } + + return am.addNewPrivateAccount(ctx, domainAccountID, claims) +} +func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + if domainAccountID != "" { + return domainAccountID, nil, nil + } + + log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain) + cancel := am.Store.AcquireGlobalLock(ctx) + + // check again if the domain has a primary account because of simultaneous requests + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + return domainAccountID, cancel, nil +} + +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) + return "", err + } + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + } + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) + return "", err + } + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return claims.AccountId, nil + } + + // We checked if the domain has a primary account already + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", err + } + + err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims) + if err != nil { + return "", err + } + + return claims.AccountId, nil +} + +func handleNotFound(err error) error { + if err == nil { + return nil + } + + e, ok := status.FromError(err) + if !ok || e.Type() != status.NotFound { + return err + } + return nil +} + +func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { + return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { @@ -2233,7 +2515,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac routes := make(map[route.ID]*route.Route) setupKeys := map[string]*SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - users[userID] = NewOwnerUser(userID) + + owner := NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } @@ -2256,6 +2542,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac PeerLoginExpiration: DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, }, } @@ -2301,12 +2590,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID + allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + for _, group := range allGroups { + allGroupsMap[group.ID] = group + } + for _, id := range autoGroups { - if group, ok := allGroups[id]; ok { + if group, ok := allGroupsMap[id]; ok { if group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { @@ -2314,5 +2608,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([ } } } + return newAutoGroups, jwtAutoGroups } diff --git a/management/server/account_test.go b/management/server/account_test.go index c5513177d..f10cc0b29 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims - type test struct { + var ( + publicDomain = "public.com" + privateDomain = "private.com" + unknownDomain = "unknown.com" + ) + + defaultInitAccount := initUserParams{ + Domain: publicDomain, + UserId: "defaultUser", + } + + initUnknown := defaultInitAccount + initUnknown.DomainCategory = UnknownCategory + initUnknown.Domain = unknownDomain + + privateInitAccount := defaultInitAccount + privateInitAccount.Domain = privateDomain + privateInitAccount.DomainCategory = PrivateCategory + + testCases := []struct { name string inputClaims jwtclaims.AuthorizationClaims inputInitUserParams initUserParams @@ -479,168 +498,143 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { expectedPrimaryDomainStatus bool expectedCreatedBy string expectedUsers []string - } - - var ( - publicDomain = "public.com" - privateDomain = "private.com" - unknownDomain = "unknown.com" - ) - - defaultInitAccount := initUserParams{ - Domain: publicDomain, - UserId: "defaultUser", - } - - testCase1 := test{ - name: "New User With Public Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: publicDomain, - UserId: "pub-domain-user", - DomainCategory: PublicCategory, + }{ + { + name: "New User With Public Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: publicDomain, + UserId: "pub-domain-user", + DomainCategory: PublicCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomainCategory: "", + expectedDomain: publicDomain, + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pub-domain-user", + expectedUsers: []string{"pub-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomainCategory: "", - expectedDomain: publicDomain, - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pub-domain-user", - expectedUsers: []string{"pub-domain-user"}, - } - - initUnknown := defaultInitAccount - initUnknown.DomainCategory = UnknownCategory - initUnknown.Domain = unknownDomain - - testCase2 := test{ - name: "New User With Unknown Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: unknownDomain, - UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, + { + name: "New User With Unknown Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: unknownDomain, + UserId: "unknown-domain-user", + DomainCategory: UnknownCategory, + }, + inputInitUserParams: initUnknown, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: unknownDomain, + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "unknown-domain-user", + expectedUsers: []string{"unknown-domain-user"}, }, - inputInitUserParams: initUnknown, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: unknownDomain, - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "unknown-domain-user", - expectedUsers: []string{"unknown-domain-user"}, - } - - testCase3 := test{ - name: "New User With Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "New User With Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, - } - - privateInitAccount := defaultInitAccount - privateInitAccount.Domain = privateDomain - privateInitAccount.DomainCategory = PrivateCategory - - testCase4 := test{ - name: "New Regular User With Existing Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "New Regular User With Existing Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "new-pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputUpdateAttrs: true, + inputInitUserParams: privateInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleUser, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, }, - inputUpdateAttrs: true, - inputInitUserParams: privateInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, - } - - testCase5 := test{ - name: "Existing User With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing User With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase6 := test{ - name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing Account Id With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputUpdateClaimAccount: true, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputUpdateClaimAccount: true, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase7 := test{ - name: "User With Private Category And Empty Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: "", - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "User With Private Category And Empty Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: "", + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: "", + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: "", - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, } - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} { + for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -671,17 +665,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - initAccount := newAccountWithId(context.Background(), "", userId, domain) + _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID := initAccount.Id - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount, err = manager.Store.GetAccount(context.Background(), accountID) + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") claims := jwtclaims.AuthorizationClaims{ @@ -885,7 +878,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } } -func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { +func TestAccountManager_GetAccountByUserID(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -894,7 +887,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "") if err != nil { t.Fatal(err) } @@ -903,14 +896,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { return } - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) - } + exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + assert.NoError(t, err) + assert.True(t, exists, "expected to get existing account after creation using userid") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), "", "") if err == nil { - t.Errorf("expected an error when user and account IDs are empty") + t.Errorf("expected an error when user ID is empty") } } @@ -1599,9 +1591,10 @@ func TestAccount_Copy(t *testing.T) { }, Routes: map[route.ID]*route.Route{ "route1": { - ID: "route1", - PeerGroups: []string{}, - Groups: []string{"group1"}, + ID: "route1", + PeerGroups: []string{}, + Groups: []string{"group1"}, + AccessControlGroups: []string{}, }, }, NameServerGroups: map[string]*nbdns.NameServerGroup{ @@ -1668,7 +1661,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) @@ -1683,7 +1676,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1695,7 +1688,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1741,7 +1734,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1769,7 +1762,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1789,7 +1782,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1801,7 +1794,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1849,7 +1842,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ @@ -1860,9 +1853,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updatedSettings.PeerLoginExpirationEnabled) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - require.NoError(t, err, "unable to get account by ID") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") @@ -1967,6 +1957,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) { } } +func TestAccount_GetInactivePeers(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + testCases := []test{ + { + name: "Peers with inactivity expiration disabled, no expired peers", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + }, + "peer-2": { + InactivityExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Two peers expired", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-30 * time.Minute), + UserID: userID, + }, + "peer-2": { + ID: "peer-2", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-2 * time.Hour), + UserID: userID, + }, + "peer-3": { + ID: "peer-3", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-1 * time.Hour), + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + "peer-2": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + } + + expiredPeers := account.GetInactivePeers() + assert.Len(t, expiredPeers, len(testCase.expectedPeers)) + for _, peer := range expiredPeers { + if _, ok := testCase.expectedPeers[peer.ID]; !ok { + t.Fatalf("expected to have peer %s expired", peer.ID) + } + } + }) + } +} + func TestAccount_GetPeersWithExpiration(t *testing.T) { type test struct { name string @@ -2036,6 +2110,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { } } +func TestAccount_GetPeersWithInactivity(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + + testCases := []test{ + { + name: "No account peers, no peers with expiration", + peers: map[string]*nbpeer.Peer{}, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration disabled, no peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration enabled, return peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + } + + actual := account.GetPeersWithInactivity() + assert.Len(t, actual, len(testCase.expectedPeers)) + if len(testCase.expectedPeers) > 0 { + for k := range testCase.expectedPeers { + contains := false + for _, peer := range actual { + if k == peer.ID { + contains = true + } + } + assert.True(t, contains) + } + } + }) + } +} + func TestAccount_GetNextPeerExpiration(t *testing.T) { type test struct { name string @@ -2197,9 +2340,175 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } } +func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expiration time.Duration + expirationEnabled bool + expectedNextRun bool + expectedNextExpiration time.Duration + } + + expectedNextExpiration := time.Minute + testCases := []test{ + { + name: "No peers, no expiration", + peers: map[string]*nbpeer.Peer{}, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "No connected peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Connected peers with disabled expiration, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Expired peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "To be expired peer, return expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + LoginExpired: false, + LastSeen: time.Now().Add(-1 * time.Second), + }, + InactivityExpirationEnabled: true, + LastLogin: time.Now().UTC(), + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Minute, + expirationEnabled: false, + expectedNextRun: true, + expectedNextExpiration: expectedNextExpiration, + }, + { + name: "Peers added with setup keys, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, + } + + expiration, ok := account.GetNextInactivePeerExpiration() + assert.Equal(t, testCase.expectedNextRun, ok) + if testCase.expectedNextRun { + assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration) + } else { + assert.Equal(t, expiration, testCase.expectedNextExpiration) + } + }) + } +} + func TestAccount_SetJWTGroups(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + // create a new account account := &Account{ + Id: "accountID", Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2210,62 +2519,120 @@ func TestAccount_SetJWTGroups(t *testing.T) { Groups: map[string]*group.Group{ "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, - Settings: &Settings{GroupsPropagationEnabled: true}, + Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, Users: map[string]*User{ - "user1": {Id: "user1"}, - "user2": {Id: "user2"}, + "user1": {Id: "user1", AccountID: "accountID"}, + "user2": {Id: "user2", AccountID: "accountID"}, }, } + assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") + t.Run("empty jwt groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.False(t, updated, "account should not be updated") - assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) t.Run("jwt match existing api group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} + assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"])) - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 2, "new group should be added") - assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("existed group not update", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group2"}) - assert.False(t, updated, "account should not be updated") - assert.Len(t, account.Groups, 2, "groups count should not be changed") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("add new group", func(t *testing.T) { - updated := account.SetJWTGroups("user2", []string{"group1", "group3"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 3, "new group should be added") - assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user2", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") + assert.NoError(t, err) + assert.Len(t, groups, 3, "new group3 should be added") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "new group should be added") }) t.Run("remove all JWT groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain") - assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") + assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") }) } @@ -2365,7 +2732,7 @@ func createManager(t TB) (*DefaultAccountManager, error) { func createStore(t TB) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4ee57f181..188494241 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -139,6 +139,13 @@ const ( PostureCheckUpdated Activity = 61 // PostureCheckDeleted indicates that the user deleted a posture check PostureCheckDeleted Activity = 62 + + PeerInactivityExpirationEnabled Activity = 63 + PeerInactivityExpirationDisabled Activity = 64 + + AccountPeerInactivityExpirationEnabled Activity = 65 + AccountPeerInactivityExpirationDisabled Activity = 66 + AccountPeerInactivityExpirationDurationUpdated Activity = 67 ) var activityMap = map[Activity]Code{ @@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{ PostureCheckCreated: {"Posture check created", "posture.check.created"}, PostureCheckUpdated: {"Posture check updated", "posture.check.updated"}, PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"}, + + PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"}, + PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"}, + + AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"}, + AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, + AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/dns_test.go b/management/server/dns_test.go index e033c1a21..c7f435b68 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -210,7 +210,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/file_store.go b/management/server/file_store.go index a8d8de002..561e133ce 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -2,24 +2,19 @@ package server import ( "context" - "errors" - "net" "os" "path/filepath" "strings" "sync" "time" - "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/util" "github.com/rs/xid" log "github.com/sirupsen/logrus" + + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/util" ) // storeFileName Store file name. Stored in the datadir @@ -41,167 +36,9 @@ type FileStore struct { mux sync.Mutex `json:"-"` storeFile string `json:"-"` - // sync.Mutex indexed by resource ID - resourceLocks sync.Map `json:"-"` - globalAccountLock sync.Mutex `json:"-"` - metrics telemetry.AppMetrics `json:"-"` } -func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error { - return f(s) -} - -func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)] - if !ok { - return status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - account.SetupKeys[setupKeyID].UsedTimes++ - - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - allGroup, err := account.GetGroupAll() - if err != nil || allGroup == nil { - return errors.New("all group not found") - } - - allGroup.Peers = append(allGroup.Peers, peerID) - - return nil -} - -func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountId) - if err != nil { - return err - } - - account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId) - - return nil -} - -func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, ok := s.Accounts[peer.AccountID] - if !ok { - return status.NewAccountNotFoundError(peer.AccountID) - } - - account.Peers[peer.ID] = peer - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) IncrementNetworkSerial(ctx context.Context, _ LockingStrength, accountId string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, ok := s.Accounts[accountId] - if !ok { - return status.NewAccountNotFoundError(accountId) - } - - account.Network.Serial++ - - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)] - if !ok { - return nil, status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - setupKey, ok := account.SetupKeys[key] - if !ok { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - - return setupKey, nil -} - -func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - var takenIps []net.IP - for _, existingPeer := range account.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps, nil -} - -func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - existingLabels := []string{} - for _, peer := range account.Peers { - if peer.DNSLabel != "" { - existingLabels = append(existingLabels, peer.DNSLabel) - } - } - return existingLabels, nil -} - -func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Network, nil -} - -type StoredAccount struct{} - // NewFileStore restores a store from the file located in the datadir func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { fs, err := restore(ctx, filepath.Join(dataDir, storeFileName)) @@ -212,25 +49,6 @@ func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetr return fs, nil } -// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir -func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { - store, err := NewFileStore(ctx, dataDir, metrics) - if err != nil { - return nil, err - } - - err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID()) - if err != nil { - return nil, err - } - - for _, account := range sqlStore.GetAllAccounts(ctx) { - store.Accounts[account.Id] = account - } - - return store, store.persist(ctx, store.storeFile) -} - // restore the state of the store from the file. // Creates a new empty store file if doesn't exist func restore(ctx context.Context, file string) (*FileStore, error) { @@ -239,7 +57,6 @@ func restore(ctx context.Context, file string) (*FileStore, error) { s := &FileStore{ Accounts: make(map[string]*Account), mux: sync.Mutex{}, - globalAccountLock: sync.Mutex{}, SetupKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string), UserID2AccountID: make(map[string]string), @@ -278,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) { account.Settings = &Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: DefaultPeerLoginExpiration, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, } } @@ -415,252 +235,6 @@ func (s *FileStore) persist(ctx context.Context, file string) error { return nil } -// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock -func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { - log.WithContext(ctx).Debugf("acquiring global lock") - start := time.Now() - s.globalAccountLock.Lock() - - unlock = func() { - s.globalAccountLock.Unlock() - log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start)) - } - - took := time.Since(start) - log.WithContext(ctx).Debugf("took %v to acquire global lock", took) - if s.metrics != nil { - s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) - } - - return unlock -} - -// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock -func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) { - log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID) - start := time.Now() - value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{}) - mtx := value.(*sync.Mutex) - mtx.Lock() - - unlock = func() { - mtx.Unlock() - log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start)) - } - - return unlock -} - -// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock -// This method is still returns a write lock as file store can't handle read locks -func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { - return s.AcquireWriteLockByUID(ctx, uniqueID) -} - -func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error { - s.mux.Lock() - defer s.mux.Unlock() - - if account.Id == "" { - return status.Errorf(status.InvalidArgument, "account id should not be empty") - } - - accountCopy := account.Copy() - - s.Accounts[accountCopy.Id] = accountCopy - - // todo check that account.Id and keyId are not exist already - // because if keyId exists for other accounts this can be bad - for keyID := range accountCopy.SetupKeys { - s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id - } - - // enforce peer to account index and delete peer to route indexes for rebuild - for _, peer := range accountCopy.Peers { - s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id - s.PeerID2AccountID[peer.ID] = accountCopy.Id - } - - for _, user := range accountCopy.Users { - s.UserID2AccountID[user.Id] = accountCopy.Id - for _, pat := range user.PATs { - s.TokenID2UserID[pat.ID] = user.Id - s.HashedPAT2TokenID[pat.HashedToken] = pat.ID - } - } - - if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount { - s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id - } - - return s.persist(ctx, s.storeFile) -} - -func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error { - s.mux.Lock() - defer s.mux.Unlock() - - if account.Id == "" { - return status.Errorf(status.InvalidArgument, "account id should not be empty") - } - - for keyID := range account.SetupKeys { - delete(s.SetupKeyID2AccountID, strings.ToUpper(keyID)) - } - - // enforce peer to account index and delete peer to route indexes for rebuild - for _, peer := range account.Peers { - delete(s.PeerKeyID2AccountID, peer.Key) - delete(s.PeerID2AccountID, peer.ID) - } - - for _, user := range account.Users { - for _, pat := range user.PATs { - delete(s.TokenID2UserID, pat.ID) - delete(s.HashedPAT2TokenID, pat.HashedToken) - } - delete(s.UserID2AccountID, user.Id) - } - - if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount { - delete(s.PrivateDomain2AccountID, account.Domain) - } - - delete(s.Accounts, account.Id) - - return s.persist(ctx, s.storeFile) -} - -// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID -func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { - s.mux.Lock() - defer s.mux.Unlock() - - delete(s.HashedPAT2TokenID, hashedToken) - - return nil -} - -// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID -func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - delete(s.TokenID2UserID, tokenID) - - return nil -} - -// GetAccountByPrivateDomain returns account by private domain -func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PrivateDomain2AccountID[strings.ToLower(domain)] - if !ok { - return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountBySetupKey returns account by setup key id -func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] - if !ok { - return nil, status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret -func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - tokenID, ok := s.HashedPAT2TokenID[token] - if !ok { - return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists") - } - - return tokenID, nil -} - -// GetUserByTokenID returns a User object a tokenID belongs to -func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) { - s.mux.Lock() - defer s.mux.Unlock() - - userID, ok := s.TokenID2UserID[tokenID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists") - } - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Users[userID].Copy(), nil -} - -func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) { - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - user := account.Users[userID].Copy() - pat := make([]PersonalAccessToken, 0, len(user.PATs)) - for _, token := range user.PATs { - if token != nil { - pat = append(pat, *token) - } - } - user.PATsG = pat - - return user, nil -} - -func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) { - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups)) - - for _, group := range account.Groups { - groupsSlice = append(groupsSlice, group) - } - - return groupsSlice, nil -} - // GetAllAccounts returns all accounts func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { s.mux.Lock() @@ -672,278 +246,6 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { return all } -// getAccount returns a reference to the Account. Should not return a copy. -func (s *FileStore) getAccount(accountID string) (*Account, error) { - account, ok := s.Accounts[accountID] - if !ok { - return nil, status.NewAccountNotFoundError(accountID) - } - - return account, nil -} - -// GetAccount returns an account for ID -func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountByUser returns a user account -func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.NewUserNotFoundError(userID) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountByPeerID returns an account for a given peer ID -func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerID2AccountID[peerID] - if !ok { - return nil, status.Errorf(status.NotFound, "provided peer ID doesn't exists %s", peerID) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - // this protection is needed because when we delete a peer, we don't really remove index peerID -> accountID. - // check Account.Peers for a match - if _, ok := account.Peers[peerID]; !ok { - delete(s.PeerID2AccountID, peerID) - log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) - return nil, status.NewPeerNotFoundError(peerID) - } - - return account.Copy(), nil -} - -// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key -func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return nil, status.NewPeerNotFoundError(peerKey) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - // this protection is needed because when we delete a peer, we don't really remove index peerKey -> accountID. - // check Account.Peers for a match - stale := true - for _, peer := range account.Peers { - if peer.Key == peerKey { - stale = false - break - } - } - if stale { - delete(s.PeerKeyID2AccountID, peerKey) - log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) - return nil, status.NewPeerNotFoundError(peerKey) - } - - return account.Copy(), nil -} - -func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return "", status.NewPeerNotFoundError(peerKey) - } - - return accountID, nil -} - -func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return "", status.NewUserNotFoundError(userID) - } - - return accountID, nil -} - -func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] - if !ok { - return "", status.NewSetupKeyNotFoundError() - } - - return accountID, nil -} - -func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return nil, status.NewPeerNotFoundError(peerKey) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - for _, peer := range account.Peers { - if peer.Key == peerKey { - return peer.Copy(), nil - } - } - - return nil, status.NewPeerNotFoundError(peerKey) -} - -func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Settings.Copy(), nil -} - -// GetInstallationID returns the installation ID from the store -func (s *FileStore) GetInstallationID() string { - return s.InstallationID -} - -// SaveInstallationID saves the installation ID -func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - s.InstallationID = ID - - return s.persist(ctx, s.storeFile) -} - -// SavePeer saves the peer in the account -func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - newPeer := peer.Copy() - - account.Peers[peer.ID] = newPeer - - s.PeerKeyID2AccountID[peer.Key] = accountID - s.PeerID2AccountID[peer.ID] = accountID - - return nil -} - -// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. -// PeerStatus will be saved eventually when some other changes occur. -func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Peers[peerID] - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerID) - } - - peer.Status = &peerStatus - - return nil -} - -// SavePeerLocation stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. -// Peer.Location will be saved eventually when some other changes occur. -func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Peers[peerWithLocation.ID] - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID) - } - - peer.Location = peerWithLocation.Location - - return nil -} - -// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. -func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Users[userID] - if peer == nil { - return status.Errorf(status.NotFound, "user %s not found", userID) - } - - peer.LastLogin = lastLogin - - return nil -} - -func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented") -} - // Close the FileStore persisting data to disk func (s *FileStore) Close(ctx context.Context) error { s.mux.Lock() @@ -958,192 +260,3 @@ func (s *FileStore) Close(ctx context.Context) error { func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } - -func (s *FileStore) SaveGroups(_ context.Context, _ LockingStrength, _ []*nbgroup.Group) error { - return status.Errorf(status.Internal, "SaveGroups is not implemented") -} - -func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) { - return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") -} - -func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return "", "", err - } - - return account.Domain, account.DomainCategory, nil -} - -// AccountExists checks whether an account exists by the given ID. -func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) { - _, exists := s.Accounts[id] - return exists, nil -} - -func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) { - return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented") -} - -func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { - return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") -} - -func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { - return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") -} - -func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) { - return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") -} - -func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) { - return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") - -} - -func (s *FileStore) SavePolicy(_ context.Context, _ LockingStrength, _ *Policy) error { - return status.Errorf(status.Internal, "SavePolicy is not implemented") -} - -func (s *FileStore) DeletePolicy(_ context.Context, _ LockingStrength, _, _ string) error { - return status.Errorf(status.Internal, "DeletePolicy is not implemented") -} - -func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") -} - -func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") -} - -func (s *FileStore) SavePostureChecks(_ context.Context, _ LockingStrength, _ *posture.Checks) error { - return status.Errorf(status.Internal, "SavePostureChecks is not implemented") -} - -func (s *FileStore) DeletePostureChecks(_ context.Context, _ LockingStrength, _, _ string) error { - return status.Errorf(status.Internal, "DeletePostureChecks is not implemented") -} - -func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) { - return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") -} - -func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) { - return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") -} - -func (s *FileStore) SaveRoute(_ context.Context, _ LockingStrength, _ *route.Route) error { - return status.Errorf(status.Internal, "SaveRoute is not implemented") -} -func (s *FileStore) DeleteRoute(_ context.Context, _ LockingStrength, _, _ string) error { - return status.Errorf(status.Internal, "DeleteRoute is not implemented") -} - -func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) { - return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") -} - -func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) { - return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") -} - -func (s *FileStore) SaveSetupKey(_ context.Context, _ LockingStrength, _ *SetupKey) error { - return status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") -} - -func (s *FileStore) DeleteSetupKey(_ context.Context, _ LockingStrength, _, _ string) error { - return status.Errorf(status.Internal, "DeleteSetupKey is not implemented") -} - -func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) { - return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") -} - -func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { - return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") -} - -func (s *FileStore) SaveNameServerGroup(_ context.Context, _ LockingStrength, _ *dns.NameServerGroup) error { - return status.Errorf(status.Internal, "SaveNameServerGroup is not implemented") -} - -func (s *FileStore) DeleteNameServerGroup(_ context.Context, _ LockingStrength, _, _ string) error { - return status.Errorf(status.Internal, "DeleteNameServerGroup is not implemented") -} - -func (s *FileStore) GetAccountPeers(_ context.Context, _ LockingStrength, _ string) ([]*nbpeer.Peer, error) { - return nil, status.Errorf(status.Internal, "GetAccountPeers is not implemented") -} - -func (s *FileStore) GetUserPeers(_ context.Context, _ LockingStrength, _, _ string) ([]*nbpeer.Peer, error) { - return nil, status.Errorf(status.Internal, "GetUserPeers is not implemented") -} - -func (s *FileStore) GetAccountPeersWithExpiration(_ context.Context, _ LockingStrength, _ string) ([]*nbpeer.Peer, error) { - return nil, status.Errorf(status.Internal, "GetAccountPeersWithExpiration is not implemented") -} - -func (s *FileStore) GetPeerByID(_ context.Context, _ LockingStrength, _ string, _ string) (*nbpeer.Peer, error) { - return nil, status.Errorf(status.Internal, "GetPeerByID is not implemented") -} - -func (s *FileStore) GetPATByID(_ context.Context, _ LockingStrength, _ string, _ string) (*PersonalAccessToken, error) { - return nil, status.Errorf(status.Internal, "GetPATByID is not implemented") -} - -func (s *FileStore) SavePAT(_ context.Context, _ LockingStrength, _ *PersonalAccessToken) error { - return status.Errorf(status.Internal, "SavePAT is not implemented") -} - -func (s *FileStore) DeletePAT(_ context.Context, _ LockingStrength, _, _ string) error { - return status.Errorf(status.Internal, "DeletePAT is not implemented") -} - -func (s *FileStore) SaveDNSSettings(_ context.Context, _ LockingStrength, _ string, _ *DNSSettings) error { - return status.Errorf(status.Internal, "SaveDNSSettings is not implemented") -} - -func (s *FileStore) SaveAccountSettings(_ context.Context, _ LockingStrength, _ string, _ *Settings) error { - return status.Errorf(status.Internal, "SaveAccountSettings is not implemented") -} - -func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error { - return status.Errorf(status.Internal, "SaveGroup is not implemented") -} - -func (s *FileStore) DeleteGroup(_ context.Context, _ LockingStrength, _, _ string) error { - return status.Errorf(status.Internal, "DeleteGroup is not implemented") -} -func (s *FileStore) DeleteGroups(_ context.Context, _ LockingStrength, _ []string, _ string) error { - return status.Errorf(status.Internal, "DeleteGroups is not implemented") -} - -func (s *FileStore) GetAccountUsers(_ context.Context, _ LockingStrength, _ string) ([]*User, error) { - return nil, status.Errorf(status.Internal, "GetAccountUsers is not implemented") -} - -func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) error { - return status.Errorf(status.Internal, "SaveUser is not implemented") -} - -func (s *FileStore) SaveUsers(_ context.Context, _ LockingStrength, _ []*User) error { - return status.Errorf(status.Internal, "SaveUsers is not implemented") -} - -func (s *FileStore) DeleteUser(_ context.Context, _ LockingStrength, _, _ string) error { - return status.Errorf(status.Internal, "DeleteUser is not implemented") -} - -func (s *FileStore) DeleteUsers(_ context.Context, _ LockingStrength, _ []string, _ string) error { - return status.Errorf(status.Internal, "DeleteUsers is not implemented") -} - -func (s *FileStore) GetAccountOwnerID(_ context.Context, _ LockingStrength, _ string) (string, error) { - return "", status.Errorf(status.Internal, "GetAccountOwnerID is not implemented") -} diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go deleted file mode 100644 index 56e46b696..000000000 --- a/management/server/file_store_test.go +++ /dev/null @@ -1,655 +0,0 @@ -package server - -import ( - "context" - "crypto/sha256" - "net" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/util" -) - -type accounts struct { - Accounts map[string]*Account -} - -func TestStalePeerIndices(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - peerID := "some_peer" - peerKey := "some_peer_key" - account.Peers[peerID] = &nbpeer.Peer{ - ID: peerID, - Key: peerKey, - } - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - account.DeletePeer(peerID) - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - _, err = store.GetAccountByPeerID(context.Background(), peerID) - require.Error(t, err, "expecting to get an error when found stale index") - - _, err = store.GetAccountByPeerPubKey(context.Background(), peerKey) - require.Error(t, err, "expecting to get an error when found stale index") -} - -func TestNewStore(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - if store.Accounts == nil || len(store.Accounts) != 0 { - t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") - } - - if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 { - t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore") - } - - if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 { - t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore") - } - - if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 { - t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore") - } - - if store.HashedPAT2TokenID == nil || len(store.HashedPAT2TokenID) != 0 { - t.Errorf("expected to create a new empty HashedPAT2TokenID map when creating a new FileStore") - } - - if store.TokenID2UserID == nil || len(store.TokenID2UserID) != 0 { - t.Errorf("expected to create a new empty TokenID2UserID map when creating a new FileStore") - } -} - -func TestSaveAccount(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - - // SaveAccount should trigger persist - err := store.SaveAccount(context.Background(), account) - if err != nil { - return - } - - if store.Accounts[account.Id] == nil { - t.Errorf("expecting Account to be stored after SaveAccount()") - } - - if store.PeerKeyID2AccountID["peerkey"] == "" { - t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()") - } - - if store.UserID2AccountID["testuser"] == "" { - t.Errorf("expecting UserID2AccountID index updated after SaveAccount()") - } - - if store.SetupKeyID2AccountID[setupKey.Key] == "" { - t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()") - } -} - -func TestDeleteAccount(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - defer store.Close(context.Background()) - - var account *Account - for _, a := range store.Accounts { - account = a - break - } - - require.NotNil(t, account, "failed to restore a FileStore file and get at least one account") - - err = store.DeleteAccount(context.Background(), account) - require.NoError(t, err, "failed to delete account, error: %v", err) - - _, ok := store.Accounts[account.Id] - require.False(t, ok, "failed to delete account") - - for id := range account.Users { - _, ok := store.UserID2AccountID[id] - assert.False(t, ok, "failed to delete UserID2AccountID index") - for _, pat := range account.Users[id].PATs { - _, ok := store.HashedPAT2TokenID[pat.HashedToken] - assert.False(t, ok, "failed to delete HashedPAT2TokenID index") - _, ok = store.TokenID2UserID[pat.ID] - assert.False(t, ok, "failed to delete TokenID2UserID index") - } - } - - for _, p := range account.Peers { - _, ok := store.PeerKeyID2AccountID[p.Key] - assert.False(t, ok, "failed to delete PeerKeyID2AccountID index") - _, ok = store.PeerID2AccountID[p.ID] - assert.False(t, ok, "failed to delete PeerID2AccountID index") - } - - for id := range account.SetupKeys { - _, ok := store.SetupKeyID2AccountID[id] - assert.False(t, ok, "failed to delete SetupKeyID2AccountID index") - } - - _, ok = store.PrivateDomain2AccountID[account.Domain] - assert.False(t, ok, "failed to delete PrivateDomain2AccountID index") - -} - -func TestStore(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - account.Groups["all"] = &group.Group{ - ID: "all", - Name: "all", - Peers: []string{"testpeer"}, - } - account.Policies = append(account.Policies, &Policy{ - ID: "all", - Name: "all", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "all", - Name: "all", - Sources: []string{"all"}, - Destinations: []string{"all"}, - }, - }, - }) - account.Policies = append(account.Policies, &Policy{ - ID: "dmz", - Name: "dmz", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "dmz", - Name: "dmz", - Enabled: true, - Sources: []string{"all"}, - Destinations: []string{"all"}, - }, - }, - }) - - // SaveAccount should trigger persist - err := store.SaveAccount(context.Background(), account) - if err != nil { - return - } - - restored, err := NewFileStore(context.Background(), store.storeFile, nil) - if err != nil { - return - } - - restoredAccount := restored.Accounts[account.Id] - if restoredAccount == nil { - t.Errorf("failed to restore a FileStore file - missing Account %s", account.Id) - return - } - - if restoredAccount.Peers["testpeer"] == nil { - t.Errorf("failed to restore a FileStore file - missing Peer testpeer") - } - - if restoredAccount.CreatedBy != "testuser" { - t.Errorf("failed to restore a FileStore file - missing Account CreatedBy") - } - - if restoredAccount.Users["testuser"] == nil { - t.Errorf("failed to restore a FileStore file - missing User testuser") - } - - if restoredAccount.Network == nil { - t.Errorf("failed to restore a FileStore file - missing Network") - } - - if restoredAccount.Groups["all"] == nil { - t.Errorf("failed to restore a FileStore file - missing Group all") - } - - if len(restoredAccount.Policies) != 2 { - t.Errorf("failed to restore a FileStore file - missing Policies") - return - } - - assert.Equal(t, account.Policies[0], restoredAccount.Policies[0], "failed to restore a FileStore file - missing Policy all") - assert.Equal(t, account.Policies[1], restoredAccount.Policies[1], "failed to restore a FileStore file - missing Policy dmz") -} - -func TestRestore(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - - require.NotNil(t, account, "failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b") - - require.NotNil(t, account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003") - - require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003") - - require.NotNil(t, account.Network, "failed to restore a FileStore file - missing Account Network") - - require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB") - - require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore wrong PATs length") - - require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length") - - require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length") - - require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length") - - require.Len(t, store.HashedPAT2TokenID, 1, "failed to restore a FileStore wrong HashedPAT2TokenID mapping length") - - require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length") -} - -func TestRestoreGroups_Migration(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - // create default group - account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - account.Groups = map[string]*group.Group{ - "cfefqs706sqkneg59g3g": { - ID: "cfefqs706sqkneg59g3g", - Name: "All", - }, - } - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err, "failed to save account") - - // restore account with default group with empty Issue field - if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil { - return - } - account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - - require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") - require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") -} - -func TestGetAccountByPrivateDomain(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - existingDomain := "test.com" - - account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) - require.NoError(t, err, "should found account") - require.Equal(t, existingDomain, account.Domain, "domains should match") - - _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") - require.Error(t, err, "should return error on domain lookup") -} - -func TestFileStore_GetAccount(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - if expected == nil { - t.Fatalf("expected account doesn't exist") - return - } - - account, err := store.GetAccount(context.Background(), expected.Id) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) - assert.Equal(t, expected.DomainCategory, account.DomainCategory) - assert.Equal(t, expected.Domain, account.Domain) - assert.Equal(t, expected.CreatedBy, account.CreatedBy) - assert.Equal(t, expected.Network.Identifier, account.Network.Identifier) - assert.Len(t, account.Peers, len(expected.Peers)) - assert.Len(t, account.Users, len(expected.Users)) - assert.Len(t, account.SetupKeys, len(expected.SetupKeys)) - assert.Len(t, account.Routes, len(expected.Routes)) - assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups)) -} - -func TestFileStore_GetTokenIDByHashedToken(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken - tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken) - if err != nil { - t.Fatal(err) - } - - expectedTokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - assert.Equal(t, expectedTokenID, tokenID) -} - -func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - store.HashedPAT2TokenID["someHashedToken"] = "someTokenId" - - err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken") - if err != nil { - t.Fatal(err) - } - - assert.Empty(t, store.HashedPAT2TokenID["someHashedToken"]) -} - -func TestFileStore_DeleteTokenID2UserIDIndex(t *testing.T) { - store := newStore(t) - store.TokenID2UserID["someTokenId"] = "someUserId" - - err := store.DeleteTokenID2UserIDIndex("someTokenId") - if err != nil { - t.Fatal(err) - } - - assert.Empty(t, store.TokenID2UserID["someTokenId"]) -} - -func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234")) - _, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:])) - - assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid") -} - -func TestFileStore_GetUserByTokenID(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - user, err := store.GetUserByTokenID(context.Background(), tokenID) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id) -} - -func TestFileStore_GetUserByTokenID_Failure(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - wrongTokenID := "someNonExistingTokenID" - _, err = store.GetUserByTokenID(context.Background(), wrongTokenID) - - assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid") -} - -func TestFileStore_SavePeerStatus(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") - if err != nil { - t.Fatal(err) - } - - // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) - assert.Error(t, err) - - // save new status of existing peer - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, - } - - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - } - - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) - if err != nil { - t.Fatal(err) - } - account, err = store.getAccount(account.Id) - if err != nil { - t.Fatal(err) - } - - actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus, *actual) -} - -func TestFileStore_SavePeerLocation(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - peer := &nbpeer.Peer{ - AccountID: account.Id, - ID: "testpeer", - Location: nbpeer.Location{ - ConnectionIP: net.ParseIP("10.0.0.0"), - CountryCode: "YY", - CityName: "City", - GeoNameID: 1, - }, - Meta: nbpeer.PeerSystemMeta{}, - } - // error is expected as peer is not in store yet - err = store.SavePeerLocation(account.Id, peer) - assert.Error(t, err) - - account.Peers[peer.ID] = peer - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - peer.Location.ConnectionIP = net.ParseIP("35.1.1.1") - peer.Location.CountryCode = "DE" - peer.Location.CityName = "Berlin" - peer.Location.GeoNameID = 2950159 - - err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) - assert.NoError(t, err) - - account, err = store.GetAccount(context.Background(), account.Id) - require.NoError(t, err) - - actual := account.Peers[peer.ID].Location - assert.Equal(t, peer.Location, actual) -} - -func newStore(t *testing.T) *FileStore { - t.Helper() - store, err := NewFileStore(context.Background(), t.TempDir(), nil) - if err != nil { - t.Errorf("failed creating a new store") - } - - return store -} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index cda3bc748..4c4ef6c3c 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -596,6 +596,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 + routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) + response.NetworkMap.RoutesFirewallRules = routesFirewallRules + response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 + return response } diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 73bd5c35d..5ff07c821 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), } if req.Settings.Extra != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 2463f830e..9d5148248 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -54,6 +54,14 @@ components: description: Period of time after which peer login expires (seconds). type: integer example: 43200 + peer_inactivity_expiration_enabled: + description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). + type: boolean + example: true + peer_inactivity_expiration: + description: Period of time of inactivity after which peer session expires (seconds). + type: integer + example: 43200 regular_users_view_blocked: description: Allows blocking regular users from viewing parts of the system. type: boolean @@ -81,6 +89,8 @@ components: required: - peer_login_expiration_enabled - peer_login_expiration + - peer_inactivity_expiration_enabled + - peer_inactivity_expiration - regular_users_view_blocked AccountExtraSettings: type: object @@ -243,6 +253,9 @@ components: login_expiration_enabled: type: boolean example: false + inactivity_expiration_enabled: + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -251,6 +264,7 @@ components: - name - ssh_enabled - login_expiration_enabled + - inactivity_expiration_enabled Peer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -327,6 +341,10 @@ components: type: string format: date-time example: "2023-05-05T09:00:35.477782Z" + inactivity_expiration_enabled: + description: Indicates whether peer inactivity expiration has been enabled or not + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -354,6 +372,7 @@ components: - last_seen - login_expiration_enabled - login_expired + - inactivity_expiration_enabled - os - ssh_enabled - user_id @@ -727,17 +746,39 @@ components: enum: ["all", "tcp", "udp", "icmp"] example: "tcp" ports: - description: Policy rule affected ports or it ranges list + description: Policy rule affected ports type: array items: type: string example: "80" + port_ranges: + description: Policy rule affected ports ranges list + type: array + items: + $ref: '#/components/schemas/RulePortRange' required: - name - enabled - bidirectional - protocol - action + + RulePortRange: + description: Policy rule affected ports range + type: object + properties: + start: + description: The starting port of the range + type: integer + example: 80 + end: + description: The ending port of the range + type: integer + example: 320 + required: + - start + - end + PolicyRuleUpdate: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' @@ -1106,6 +1147,12 @@ components: description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore type: boolean example: true + access_control_groups: + description: Access control group identifier associated with route. + type: array + items: + type: string + example: "chacbco6lnnbn6cg5s91" required: - id - description diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index b219d38fd..e2870d5d8 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -220,6 +220,12 @@ type AccountSettings struct { // JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups. JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"` + // PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds). + PeerInactivityExpiration int `json:"peer_inactivity_expiration"` + + // PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). + PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"` + // PeerLoginExpiration Period of time after which peer login expires (seconds). PeerLoginExpiration int `json:"peer_login_expiration"` @@ -538,6 +544,9 @@ type Peer struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -613,6 +622,9 @@ type PeerBatch struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string // PeerRequest defines model for PeerRequest. type PeerRequest struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` - LoginExpirationEnabled bool `json:"login_expiration_enabled"` - Name string `json:"name"` - SshEnabled bool `json:"ssh_enabled"` + ApprovalRequired *bool `json:"approval_required,omitempty"` + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + LoginExpirationEnabled bool `json:"login_expiration_enabled"` + Name string `json:"name"` + SshEnabled bool `json:"ssh_enabled"` } // PersonalAccessToken defines model for PersonalAccessToken. @@ -780,7 +793,10 @@ type PolicyRule struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -816,7 +832,10 @@ type PolicyRuleMinimum struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -852,7 +871,10 @@ type PolicyRuleUpdate struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -935,6 +957,9 @@ type ProcessCheck struct { // Route defines model for Route. type Route struct { + // AccessControlGroups Access control group identifier associated with route. + AccessControlGroups *[]string `json:"access_control_groups,omitempty"` + // Description Route description Description string `json:"description"` @@ -977,6 +1002,9 @@ type Route struct { // RouteRequest defines model for RouteRequest. type RouteRequest struct { + // AccessControlGroups Access control group identifier associated with route. + AccessControlGroups *[]string `json:"access_control_groups,omitempty"` + // Description Route description Description string `json:"description"` @@ -1011,6 +1039,15 @@ type RouteRequest struct { PeerGroups *[]string `json:"peer_groups,omitempty"` } +// RulePortRange Policy rule affected ports range +type RulePortRange struct { + // End The ending port of the range + End int `json:"end"` + + // Start The starting port of the range + Start int `json:"start"` +} + // SetupKey defines model for SetupKey. type SetupKey struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ef94f22b9..3f8a8554d 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -82,7 +82,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa AuthCfg: authCfg, } - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 4fbbc3106..a5856a0e4 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -7,6 +7,8 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -14,7 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) // PeersHandler is a handler that returns peers of the account @@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, SSHEnabled: req.SshEnabled, Name: req.Name, LoginExpirationEnabled: req.LoginExpirationEnabled, + + InactivityExpirationEnabled: req.InactivityExpirationEnabled, } if req.ApprovalRequired != nil { @@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD } return &api.Peer{ - Id: peer.ID, - Name: peer.Name, - Ip: peer.IP.String(), - ConnectionIp: peer.Location.ConnectionIP.String(), - Connected: peer.Status.Connected, - LastSeen: peer.Status.LastSeen, - Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), - KernelVersion: peer.Meta.KernelVersion, - GeonameId: int(peer.Location.GeoNameID), - Version: peer.Meta.WtVersion, - Groups: groupsInfo, - SshEnabled: peer.SSHEnabled, - Hostname: peer.Meta.Hostname, - UserId: peer.UserID, - UiVersion: peer.Meta.UIVersion, - DnsLabel: fqdn(peer, dnsDomain), - LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.LastLogin, - LoginExpired: peer.Status.LoginExpired, - ApprovalRequired: !approved, - CountryCode: peer.Location.CountryCode, - CityName: peer.Location.CityName, - SerialNumber: peer.Meta.SystemSerialNumber, + Id: peer.ID, + Name: peer.Name, + Ip: peer.IP.String(), + ConnectionIp: peer.Location.ConnectionIP.String(), + Connected: peer.Status.Connected, + LastSeen: peer.Status.LastSeen, + Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), + KernelVersion: peer.Meta.KernelVersion, + GeonameId: int(peer.Location.GeoNameID), + Version: peer.Meta.WtVersion, + Groups: groupsInfo, + SshEnabled: peer.SSHEnabled, + Hostname: peer.Meta.Hostname, + UserId: peer.UserID, + UiVersion: peer.Meta.UIVersion, + DnsLabel: fqdn(peer, dnsDomain), + LoginExpirationEnabled: peer.LoginExpirationEnabled, + LastLogin: peer.LastLogin, + LoginExpired: peer.Status.LoginExpired, + ApprovalRequired: !approved, + CountryCode: peer.Location.CountryCode, + CityName: peer.Location.CityName, + SerialNumber: peer.Meta.SystemSerialNumber, + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } @@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, SerialNumber: peer.Meta.SystemSerialNumber, + + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index a9a87800f..30c84dbfc 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -173,6 +173,11 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } + if (rule.Ports != nil && len(*rule.Ports) != 0) && (rule.PortRanges != nil && len(*rule.PortRanges) != 0) { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either individual ports or port ranges, not both"), w) + return + } + if rule.Ports != nil && len(*rule.Ports) != 0 { for _, v := range *rule.Ports { if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 { @@ -183,10 +188,23 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID } } + if rule.PortRanges != nil && len(*rule.PortRanges) != 0 { + for _, portRange := range *rule.PortRanges { + if portRange.Start < 1 || portRange.End > 65535 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) + return + } + pr.PortRanges = append(pr.PortRanges, server.RulePortRange{ + Start: uint16(portRange.Start), + End: uint16(portRange.End), + }) + } + } + // validate policy object switch pr.Protocol { case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: - if len(pr.Ports) != 0 { + if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return } @@ -195,7 +213,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: - if !pr.Bidirectional && len(pr.Ports) == 0 { + if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } @@ -321,6 +339,17 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rule.Ports = &portsCopy } + if len(r.PortRanges) != 0 { + portRanges := make([]api.RulePortRange, 0, len(r.PortRanges)) + for _, portRange := range r.PortRanges { + portRanges = append(portRanges, api.RulePortRange{ + End: int(portRange.End), + Start: int(portRange.Start), + }) + } + rule.PortRanges = &portRanges + } + for _, gid := range r.Sources { _, ok := cache[gid] if ok { diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 0932e6445..ce4edee4f 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -117,9 +117,14 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { peerGroupIds = *req.PeerGroups } + var accessControlGroupIds []string + if req.AccessControlGroups != nil { + accessControlGroupIds = *req.AccessControlGroups + } + newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, - req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute, - ) + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute) + if err != nil { util.WriteError(r.Context(), err, w) return @@ -233,6 +238,10 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } + if req.AccessControlGroups != nil { + newRoute.AccessControlGroups = *req.AccessControlGroups + } + err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) if err != nil { util.WriteError(r.Context(), err, w) @@ -326,6 +335,9 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) { if len(serverRoute.PeerGroups) > 0 { route.PeerGroups = &serverRoute.PeerGroups } + if len(serverRoute.AccessControlGroups) > 0 { + route.AccessControlGroups = &serverRoute.AccessControlGroups + } return route, nil } diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 2c367cac3..83bd7004d 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -105,7 +105,7 @@ func initRoutesTestData() *RoutesHandler { } return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { + CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } @@ -119,18 +119,19 @@ func initRoutesTestData() *RoutesHandler { } return &route.Route{ - ID: existingRouteID, - NetID: netID, - Peer: peerID, - PeerGroups: peerGroups, - Network: prefix, - Domains: domains, - NetworkType: networkType, - Description: description, - Masquerade: masquerade, - Enabled: enabled, - Groups: groups, - KeepRoute: keepRoute, + ID: existingRouteID, + NetID: netID, + Peer: peerID, + PeerGroups: peerGroups, + Network: prefix, + Domains: domains, + NetworkType: networkType, + Description: description, + Masquerade: masquerade, + Enabled: enabled, + Groups: groups, + KeepRoute: keepRoute, + AccessControlGroups: accessControlGroups, }, nil }, SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error { @@ -268,6 +269,27 @@ func TestRoutesHandlers(t *testing.T) { Groups: []string{existingGroupID}, }, }, + { + name: "POST OK With Access Control Groups", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRoute: &api.Route{ + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: toPtr("192.168.0.0/16"), + Peer: &existingPeerID, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + AccessControlGroups: &[]string{existingGroupID}, + }, + }, { name: "POST Non Linux Peer", requestType: http.MethodPost, diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 729b49733..9d7626844 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -2,10 +2,12 @@ package idp import ( "context" + "errors" "fmt" "io" "net/http" "net/url" + "slices" "strings" "sync" "time" @@ -97,6 +99,42 @@ type zitadelUserResponse struct { PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"` } +// readZitadelError parses errors returned by the zitadel APIs from a response. +func readZitadelError(body io.ReadCloser) error { + bodyBytes, err := io.ReadAll(body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + + helper := JsonParser{} + var target map[string]interface{} + err = helper.Unmarshal(bodyBytes, &target) + if err != nil { + return fmt.Errorf("error unparsable body: %s", string(bodyBytes)) + } + + // ensure keys are ordered for consistent logging behaviour. + errorKeys := make([]string, 0, len(target)) + for k := range target { + errorKeys = append(errorKeys, k) + } + slices.Sort(errorKeys) + + var errsOut []string + for _, k := range errorKeys { + if _, isEmbedded := target[k].(map[string]interface{}); isEmbedded { + continue + } + errsOut = append(errsOut, fmt.Sprintf("%s: %v", k, target[k])) + } + + if len(errsOut) == 0 { + return errors.New("unknown error") + } + + return errors.New(strings.Join(errsOut, " ")) +} + // NewZitadelManager creates a new instance of the ZitadelManager. func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() @@ -176,7 +214,8 @@ func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Respon } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode) + zErr := readZitadelError(resp.Body) + return nil, fmt.Errorf("unable to get zitadel token, statusCode %d, zitadel: %w", resp.StatusCode, zErr) } return resp, nil @@ -489,7 +528,9 @@ func (zm *ZitadelManager) post(ctx context.Context, resource string, body string zm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode) + zErr := readZitadelError(resp.Body) + + return nil, fmt.Errorf("unable to post %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr) } return io.ReadAll(resp.Body) @@ -561,7 +602,9 @@ func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values zm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode) + zErr := readZitadelError(resp.Body) + + return nil, fmt.Errorf("unable to get %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr) } return io.ReadAll(resp.Body) diff --git a/management/server/idp/zitadel_test.go b/management/server/idp/zitadel_test.go index 6bc612e78..722f94fe0 100644 --- a/management/server/idp/zitadel_test.go +++ b/management/server/idp/zitadel_test.go @@ -66,7 +66,6 @@ func TestNewZitadelManager(t *testing.T) { } func TestZitadelRequestJWTToken(t *testing.T) { - type requestJWTTokenTest struct { name string inputCode int @@ -88,15 +87,14 @@ func TestZitadelRequestJWTToken(t *testing.T) { requestJWTTokenTestCase2 := requestJWTTokenTest{ name: "Request Bad Status Code", inputCode: 400, - inputRespBody: "{}", + inputRespBody: "{\"error\": \"invalid_scope\", \"error_description\":\"openid missing\"}", helper: JsonParser{}, - expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"), + expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: error: invalid_scope error_description: openid missing"), expectedToken: "", } for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} { t.Run(testCase.name, func(t *testing.T) { - jwtReqClient := mockHTTPClient{ resBody: testCase.inputRespBody, code: testCase.inputCode, @@ -156,7 +154,7 @@ func TestZitadelParseRequestJWTResponse(t *testing.T) { } parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{ name: "Parse Bad json JWT Body", - inputRespBody: "", + inputRespBody: "{}", helper: JsonParser{}, expectedToken: "", expectedExpiresIn: 0, @@ -254,7 +252,7 @@ func TestZitadelAuthenticate(t *testing.T) { inputCode: 400, inputResBody: "{}", helper: JsonParser{}, - expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"), + expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: unknown error"), expectedCode: 200, expectedToken: "", } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index ff09129bd..dc8765e19 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -6,7 +6,6 @@ import ( "io" "net" "os" - "path/filepath" "runtime" "sync" "sync/atomic" @@ -89,14 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error func Test_SyncProtocol(t *testing.T) { dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{ + mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -117,6 +109,7 @@ func Test_SyncProtocol(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return @@ -412,18 +405,18 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { } } -func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) { +func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") if err != nil { - return nil, nil, "", err + return nil, nil, "", nil, err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir) + + store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { - return nil, nil, "", err + t.Fatal(err) } - t.Cleanup(cleanUp) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} @@ -437,7 +430,8 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA eventStore, nil, false, MocIntegratedValidator{}, metrics) if err != nil { - return nil, nil, "", err + cleanup() + return nil, nil, "", cleanup, err } secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) @@ -445,7 +439,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA ephemeralMgr := NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr) if err != nil { - return nil, nil, "", err + return nil, nil, "", cleanup, err } mgmtProto.RegisterManagementServiceServer(s, mgmtServer) @@ -455,7 +449,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA } }() - return s, accountManager, lis.Addr().String(), nil + return s, accountManager, lis.Addr().String(), cleanup, nil } func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) { @@ -475,7 +469,9 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn, nil } + func Test_SyncStatusRace(t *testing.T) { + t.Skip() if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { t.Skip("Skipping on CI and Postgres store") } @@ -487,16 +483,10 @@ func Test_SyncStatusRace(t *testing.T) { } func testSyncStatusRace(t *testing.T) { t.Helper() + t.Skip() dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{ + mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -517,6 +507,7 @@ func testSyncStatusRace(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return @@ -638,6 +629,7 @@ func testSyncStatusRace(t *testing.T) { } func Test_LoginPerformance(t *testing.T) { + t.Skip() if os.Getenv("CI") == "true" || runtime.GOOS == "windows" { t.Skip("Skipping test on CI or Windows") } @@ -665,15 +657,8 @@ func Test_LoginPerformance(t *testing.T) { t.Run(bc.name, func(t *testing.T) { t.Helper() dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, am, _, err := startManagementForTest(t, &Config{ + mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -694,6 +679,7 @@ func Test_LoginPerformance(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return diff --git a/management/server/management_test.go b/management/server/management_test.go index 3956d96b1..d53c177d6 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -5,7 +5,6 @@ import ( "math/rand" "net" "os" - "path/filepath" "runtime" sync2 "sync" "time" @@ -52,8 +51,6 @@ var _ = Describe("Management service", func() { dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*") Expect(err).NotTo(HaveOccurred()) - err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json")) - Expect(err).NotTo(HaveOccurred()) var listener net.Listener config := &server.Config{} @@ -61,7 +58,7 @@ var _ = Describe("Management service", func() { Expect(err).NotTo(HaveOccurred()) config.Datadir = dataDir - s, listener = startServer(config) + s, listener = startServer(config, dataDir, "testdata/store.sql") addr = listener.Addr().String() client, conn = createRawClient(addr) @@ -530,12 +527,12 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn } -func startServer(config *server.Config) (*grpc.Server, net.Listener) { +func startServer(config *server.Config, dataDir string, testFile string) (*grpc.Server, net.Listener) { lis, err := net.Listen("tcp", ":0") Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b3fb03f3d..51eef2af3 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,7 +27,8 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) + AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -58,7 +59,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -194,14 +195,22 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { - if am.GetAccountIDByUserOrAccountIdFunc != nil { - return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// AccountExists mock implementation of AccountExists from server.AccountManager interface +func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + if am.AccountExistsFunc != nil { + return am.AccountExistsFunc(ctx, accountID) + } + return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented") +} + +// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { + if am.GetAccountIDByUserIdFunc != nil { + return am.GetAccountIDByUserIdFunc(ctx, userId, domain) } return "", status.Errorf( codes.Unimplemented, - "method GetAccountIDByUserOrAccountID is not implemented", + "method GetAccountIDByUserID is not implemented", ) } @@ -367,7 +376,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, if am.DeleteRuleFunc != nil { return am.DeleteRuleFunc(ctx, accountID, ruleID, userID) } - return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented") + return status.Errorf(codes.Unimplemented, "method DeletePeerRule is not implemented") } // GetPolicy mock implementation of GetPolicy from server.AccountManager interface @@ -442,9 +451,9 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 5f8545243..8a3fe6eb0 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -773,7 +773,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { func createNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/network.go b/management/server/network.go index 0e7d753a7..a5b188b46 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -26,12 +26,13 @@ const ( ) type NetworkMap struct { - Peers []*nbpeer.Peer - Network *Network - Routes []*route.Route - DNSConfig nbdns.Config - OfflinePeers []*nbpeer.Peer - FirewallRules []*FirewallRule + Peers []*nbpeer.Peer + Network *Network + Routes []*route.Route + DNSConfig nbdns.Config + OfflinePeers []*nbpeer.Peer + FirewallRules []*FirewallRule + RoutesFirewallRules []*RouteFirewallRule } type Network struct { diff --git a/management/server/peer.go b/management/server/peer.go index f962ea414..a4c7e1266 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -110,6 +110,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } + expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) + if err != nil { + return err + } + + if peer.AddedWithSSOLogin() { + if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, account) + } + + if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + } + + if expired { + // we need to update other peers because when peer login expires all other peers are notified to disconnect from + // the expired one. Here we notify them that connection is now allowed again. + am.updateAccountPeers(ctx, account) + } + + return nil +} + +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -138,26 +163,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) if err != nil { - return err + return false, err } - if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - //TODO: use refactored method - //am.checkAndSchedulePeerLoginExpiration(ctx, account) - } - - if oldStatus.LoginExpired { - // we need to update other peers because when peer login expires all other peers are notified to disconnect from - // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account) - } - - return nil + return oldStatus.LoginExpired, nil } -// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. +// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -216,8 +230,26 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - //TODO: use refactored method - //am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, account) + } + } + + if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { + + if !peer.AddedWithSSOLogin() { + return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + } + + peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled + + event := activity.PeerInactivityExpirationEnabled + if !update.InactivityExpirationEnabled { + event = activity.PeerInactivityExpirationDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + + if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) } } @@ -444,23 +476,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s registrationTime := time.Now().UTC() newPeer = &nbpeer.Peer{ - ID: xid.New().String(), - AccountID: accountID, - Key: peer.Key, - SetupKey: upperKey, - IP: freeIP, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: freeLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: registrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - Location: peer.Location, + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + SetupKey: upperKey, + IP: freeIP, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + DNSLabel: freeLabel, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: registrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + InactivityExpirationEnabled: addedByUser, } opEvent.TargetID = newPeer.ID opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) @@ -504,7 +537,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + err = transaction.IncrementNetworkSerial(ctx, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -695,6 +728,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) updateRemotePeers := false if login.UserID != "" { + if peer.UserID != login.UserID { + log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) + return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user") + } + changed, err := am.handleUserPeer(ctx, peer, settings) if err != nil { return nil, nil, nil, err @@ -1003,72 +1041,6 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account wg.Wait() } -// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. -// If there is no peer that expires this function returns false and a duration of 0. -// This function only considers peers that haven't been expired yet and that are connected. -func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account settings: %v", err) - return 0, false - } - - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) - return 0, false - } - - if len(peersWithExpiry) == 0 { - return 0, false - } - var nextExpiry *time.Duration - for _, peer := range peersWithExpiry { - // consider only connected peers because others will require login on connecting to the management server - if peer.Status.LoginExpired || !peer.Status.Connected { - continue - } - _, duration := peer.LoginExpired(settings.PeerLoginExpiration) - if nextExpiry == nil || duration < *nextExpiry { - // if expiration is below 1s return 1s duration - // this avoids issues with ticker that can't be set to < 0 - if duration < time.Second { - return time.Second, true - } - nextExpiry = &duration - } - } - - if nextExpiry == nil { - return 0, false - } - - return *nextExpiry, true -} - -// getExpiredPeers returns peers that have been expired. -func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - - var peers []*nbpeer.Peer - for _, peer := range peersWithExpiry { - expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) - if expired { - peers = append(peers, peer) - } - } - - return peers, nil -} - func ConvertSliceToMap(existingLabels []string) map[string]struct{} { labelMap := make(map[string]struct{}, len(existingLabels)) for _, label := range existingLabels { diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 3d9ba18e9..9a53459a8 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -38,6 +38,8 @@ type Peer struct { // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin LoginExpirationEnabled bool + + InactivityExpirationEnabled bool // LastLogin the time when peer performed last login operation LastLogin time.Time // CreatedAt records the time the peer was created @@ -187,6 +189,8 @@ func (p *Peer) Copy() *Peer { CreatedAt: p.CreatedAt, Ephemeral: p.Ephemeral, Location: p.Location, + + InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } @@ -219,6 +223,22 @@ func (p *Peer) MarkLoginExpired(expired bool) { p.Status = newStatus } +// SessionExpired indicates whether the peer's session has expired or not. +// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired. +// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired). +// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property. +// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled. +// Only peers added by interactive SSO login can be expired. +func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) { + if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected { + return false, 0 + } + expiresAt := p.Status.LastSeen.Add(expiresIn) + now := time.Now() + timeLeft := expiresAt.Sub(now) + return timeLeft <= 0, timeLeft +} + // LoginExpired indicates whether the peer's login has expired or not. // If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. // Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). diff --git a/management/server/peer_test.go b/management/server/peer_test.go index d329e04bc..c5edb5636 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) { } } +func TestPeer_SessionExpired(t *testing.T) { + tt := []struct { + name string + expirationEnabled bool + lastLogin time.Time + connected bool + expected bool + accountSettings *Settings + }{ + { + name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire", + expirationEnabled: false, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Hour, + }, + expected: false, + }, + { + name: "Peer Inactivity Should Expire", + expirationEnabled: true, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: true, + }, + { + name: "Peer Inactivity Should Not Expire", + expirationEnabled: true, + connected: true, + lastLogin: time.Now().UTC(), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: false, + }, + } + + for _, c := range tt { + t.Run(c.name, func(t *testing.T) { + peerStatus := &nbpeer.PeerStatus{ + Connected: c.connected, + } + peer := &nbpeer.Peer{ + InactivityExpirationEnabled: c.expirationEnabled, + LastLogin: c.lastLogin, + Status: peerStatus, + UserID: userID, + } + + expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration) + assert.Equal(t, expired, c.expected) + }) + } +} + func TestAccountManager_GetNetworkMap(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -646,7 +708,6 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { }) } - } func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) { @@ -991,9 +1052,9 @@ func TestToSyncResponse(t *testing.T) { // assert network map Firewall assert.Equal(t, 1, len(response.NetworkMap.FirewallRules)) assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP) - assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction) - assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action) - assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol) + assert.Equal(t, proto.RuleDirection_IN, response.NetworkMap.FirewallRules[0].Direction) + assert.Equal(t, proto.RuleAction_ACCEPT, response.NetworkMap.FirewallRules[0].Action) + assert.Equal(t, proto.RuleProtocol_TCP, response.NetworkMap.FirewallRules[0].Protocol) assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port) // assert posture checks assert.Equal(t, 1, len(response.Checks)) @@ -1005,7 +1066,11 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} @@ -1066,7 +1131,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} @@ -1128,7 +1197,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} @@ -1177,6 +1250,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed) + assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC()) assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) } diff --git a/management/server/policy.go b/management/server/policy.go index d4d51b39c..75647de44 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,7 +3,7 @@ package server import ( "context" _ "embed" - "fmt" + "slices" "strconv" "strings" @@ -76,6 +76,12 @@ type PolicyUpdateOperation struct { Values []string } +// RulePortRange represents a range of ports for a firewall rule. +type RulePortRange struct { + Start uint16 + End uint16 +} + // PolicyRule is the metadata of the policy type PolicyRule struct { // ID of the policy rule @@ -110,6 +116,9 @@ type PolicyRule struct { // Ports or it ranges list Ports []string `gorm:"serializer:json"` + + // PortRanges a list of port ranges. + PortRanges []RulePortRange `gorm:"serializer:json"` } // Copy returns a copy of a policy rule @@ -125,10 +134,12 @@ func (pm *PolicyRule) Copy() *PolicyRule { Bidirectional: pm.Bidirectional, Protocol: pm.Protocol, Ports: make([]string, len(pm.Ports)), + PortRanges: make([]RulePortRange, len(pm.PortRanges)), } copy(rule.Destinations, pm.Destinations) copy(rule.Sources, pm.Sources) copy(rule.Ports, pm.Ports) + copy(rule.PortRanges, pm.PortRanges) return rule } @@ -321,7 +332,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic } if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies") + return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) @@ -329,48 +340,20 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic // SavePolicy in the store func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - if !user.HasAdminPower() || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "only admin users are allowed to update policies") - } - - groups, err := am.Store.GetAccountGroups(ctx, accountID) - if err != nil { + if err = am.savePolicy(account, policy, isUpdate); err != nil { return err } - postureChecks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - - for index, rule := range policy.Rules { - rule.Sources = getValidGroupIDs(groups, rule.Sources) - rule.Destinations = getValidGroupIDs(groups, rule.Destinations) - policy.Rules[index] = rule - } - - if policy.SourcePostureChecks != nil { - policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) - } - - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) - if err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) - } - - err = transaction.SavePolicy(ctx, LockingStrengthUpdate, policy) - if err != nil { - return fmt.Errorf("failed to save policy: %w", err) - } - return nil - }) - if err != nil { + account.Network.IncSerial() + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } @@ -380,10 +363,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } am.updateAccountPeers(ctx, account) return nil @@ -391,42 +370,26 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user // DeletePolicy from the store func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - if !user.HasAdminPower() || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "deleting policies is restricted to admin users only") - } - - policy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) + policy, err := am.deletePolicy(account, policyID) if err != nil { return err } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) - if err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) - } - - err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, policyID, accountID) - if err != nil { - return fmt.Errorf("failed to delete policy: %w", err) - } - return nil - }) - if err != nil { + account.Network.IncSerial() + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) + am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } am.updateAccountPeers(ctx, account) return nil @@ -440,7 +403,7 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us } if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies") + return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) @@ -463,36 +426,47 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) return policy, nil } -func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(update)) - for i := range update { - direction := proto.FirewallRule_IN - if update[i].Direction == firewallRuleDirectionOUT { - direction = proto.FirewallRule_OUT - } - action := proto.FirewallRule_ACCEPT - if update[i].Action == string(PolicyTrafficActionDrop) { - action = proto.FirewallRule_DROP +// savePolicy saves or updates a policy in the given account. +// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. +func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { + for index, rule := range policyToSave.Rules { + rule.Sources = filterValidGroupIDs(account, rule.Sources) + rule.Destinations = filterValidGroupIDs(account, rule.Destinations) + policyToSave.Rules[index] = rule + } + + if policyToSave.SourcePostureChecks != nil { + policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) + } + + if isUpdate { + policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) + if policyIdx < 0 { + return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) } - protocol := proto.FirewallRule_UNKNOWN - switch PolicyRuleProtocolType(update[i].Protocol) { - case PolicyRuleProtocolALL: - protocol = proto.FirewallRule_ALL - case PolicyRuleProtocolTCP: - protocol = proto.FirewallRule_TCP - case PolicyRuleProtocolUDP: - protocol = proto.FirewallRule_UDP - case PolicyRuleProtocolICMP: - protocol = proto.FirewallRule_ICMP - } + // Update the existing policy + account.Policies[policyIdx] = policyToSave + return nil + } + + // Add the new policy to the account + account.Policies = append(account.Policies, policyToSave) + + return nil +} + +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] result[i] = &proto.FirewallRule{ - PeerIP: update[i].PeerIP, - Direction: direction, - Action: action, - Protocol: protocol, - Port: update[i].Port, + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result @@ -576,36 +550,28 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { return nil } -// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. -func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string { - validPostureCheckIDs := make(map[string]struct{}) - for _, check := range postureChecks { - validPostureCheckIDs[check.ID] = struct{}{} - } - - validIDs := make([]string, 0, len(postureChecksIds)) +// filterValidPostureChecks filters and returns the posture check IDs from the given list +// that are valid within the provided account. +func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { + result := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - if _, exists := validPostureCheckIDs[id]; exists { - validIDs = append(validIDs, id) + for _, postureCheck := range account.PostureChecks { + if id == postureCheck.ID { + result = append(result, id) + continue + } } } - - return validIDs + return result } -// getValidGroupIDs filters and returns only the valid group IDs from the provided list. -func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string { - validGroupIDs := make(map[string]struct{}) - for _, group := range groups { - validGroupIDs[group.ID] = struct{}{} - } - - validIDs := make([]string, 0, len(groupIDs)) - for _, id := range groupIDs { - if _, exists := validGroupIDs[id]; exists { - validIDs = append(validIDs, id) +// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. +func filterValidGroupIDs(account *Account, groupIDs []string) []string { + result := make([]string, 0, len(groupIDs)) + for _, groupID := range groupIDs { + if _, exists := account.Groups[groupID]; exists { + result = append(result, groupID) } } - - return validIDs + return result } diff --git a/management/server/route.go b/management/server/route.go index 9afe5d418..39ee6170c 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,9 +4,15 @@ import ( "context" "fmt" "net/netip" + "slices" + "strconv" + "strings" "unicode/utf8" "github.com/rs/xid" + log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" @@ -15,6 +21,30 @@ import ( "github.com/netbirdio/netbird/route" ) +// RouteFirewallRule a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + // SourceRanges IP ranges of the routing peers. + SourceRanges []string + + // Action of the traffic when the rule is applicable + Action string + + // Destination a network prefix for the routed traffic + Destination string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port uint16 + + // PortRange represents the range of ports for a firewall rule + PortRange RulePortRange + + // isDynamic indicates whether the rule is for DNS routing + IsDynamic bool +} + // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) @@ -32,21 +62,7 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { // routes can have both peer and peer_groups - routes, err := am.Store.GetAccountRoutes(context.Background(), LockingStrengthShare, account.Id) - if err != nil { - return err - } - - routesWithPrefix := make([]*route.Route, 0) - for _, r := range routes { - dynamic := r.IsDynamic() - if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || - !dynamic && r.Network.String() == prefix.String() { - routesWithPrefix = append(routesWithPrefix, r) - } - } - - //routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) + routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) // lets remember all the peers and the peer groups from routesWithPrefix seenPeers := make(map[string]bool) @@ -65,8 +81,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account for _, groupID := range prefixRoute.PeerGroups { seenPeerGroups[groupID] = true - group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id) - if err != nil || group == nil { + group := account.GetGroup(groupID) + if group == nil { return status.Errorf( status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", getRouteDescriptor(prefix, domains), groupID, @@ -81,11 +97,10 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account if peerID != "" { // check that peerID exists and is not in any route as single peer or part of the group - peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id) - if err != nil || peer == nil { + peer := account.GetPeer(peerID) + if peer == nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } - if _, ok := seenPeers[peerID]; ok { return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) @@ -94,11 +109,7 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account // check that peerGroupIDs are not in any route peerGroups list for _, groupID := range peerGroupIDs { - // we validated the group existence before entering this function, no need to check again. - group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id) - if err != nil || group == nil { - return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID) - } + group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. if _, ok := seenPeerGroups[groupID]; ok { return status.Errorf( @@ -109,11 +120,10 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix for _, id := range group.Peers { if _, ok := seenPeers[id]; ok { - peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id) - if err != nil || peer == nil { + peer := account.GetPeer(id) + if peer == nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } - return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s from the group %s already has this route", getRouteDescriptor(prefix, domains), peer.Name, group.Name) @@ -132,7 +142,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { } // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -171,10 +181,17 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri newRoute.ID = route.ID(xid.New().String()) if len(peerGroupIDs) > 0 { - //err = validateGroups(peerGroupIDs, account.Groups) - //if err != nil { - // return nil, err - //} + err = validateGroups(peerGroupIDs, account.Groups) + if err != nil { + return nil, err + } + } + + if len(accessControlGroupIDs) > 0 { + err = validateGroups(accessControlGroupIDs, account.Groups) + if err != nil { + return nil, err + } } err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) @@ -190,10 +207,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - //err = validateGroups(groups, account.Groups) - //if err != nil { - // return nil, err - //} + err = validateGroups(groups, account.Groups) + if err != nil { + return nil, err + } newRoute.Peer = peerID newRoute.PeerGroups = peerGroupIDs @@ -207,6 +224,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri newRoute.Enabled = enabled newRoute.Groups = groups newRoute.KeepRoute = keepRoute + newRoute.AccessControlGroups = accessControlGroupIDs if account.Routes == nil { account.Routes = make(map[route.ID]*route.Route) @@ -228,19 +246,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri // SaveRoute saves route func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + if routeToSave == nil { return status.Errorf(status.InvalidArgument, "route provided is nil") } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) - if err != nil { - return err - } - - if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "user not allowed to update route") - } - if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric { return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) } @@ -249,14 +261,16 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - // Do not allow non-Linux peers - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, routeToSave.Peer, accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - if peer.Meta.GoOS != "linux" { - return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + // Do not allow non-Linux peers + if peer := account.GetPeer(routeToSave.Peer); peer != nil { + if peer.Meta.GoOS != "linux" { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } } if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { @@ -275,78 +289,67 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + if len(routeToSave.PeerGroups) > 0 { + err = validateGroups(routeToSave.PeerGroups, account.Groups) + if err != nil { + return err + } + } + + if len(routeToSave.AccessControlGroups) > 0 { + err = validateGroups(routeToSave.AccessControlGroups, account.Groups) + if err != nil { + return err + } + } + + err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) if err != nil { return err } - _ = groups - //if len(routeToSave.PeerGroups) > 0 { - // err = validateGroups(routeToSave.PeerGroups, groups) - // if err != nil { - // return err - // } - //} + err = validateGroups(routeToSave.Groups, account.Groups) + if err != nil { + return err + } - //err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) - //if err != nil { - // return err - //} - // - //err = validateGroups(routeToSave.Groups, account.Groups) - //if err != nil { - // return err - //} - // - //account.Routes[routeToSave.ID] = routeToSave - // - //account.Network.IncSerial() - //if err = am.Store.SaveAccount(ctx, account); err != nil { - // return err - //} - // - //am.updateAccountPeers(ctx, account) - // - //am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) + account.Routes[routeToSave.ID] = routeToSave + + account.Network.IncSerial() + if err = am.Store.SaveAccount(ctx, account); err != nil { + return err + } + + am.updateAccountPeers(ctx, account) + + am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) return nil } // DeleteRoute deletes route with routeID func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "user not allowed to delete route") + routy := account.Routes[routeID] + if routy == nil { + return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID) } + delete(account.Routes, routeID) - route, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) - if err != nil { + account.Network.IncSerial() + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) - if err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) - } + am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - err = transaction.DeleteRoute(ctx, LockingStrengthUpdate, string(routeID), accountID) - if err != nil { - return fmt.Errorf("failed to delete route: %w", err) - } - - return nil - }) - am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } am.updateAccountPeers(ctx, account) return nil @@ -393,3 +396,248 @@ func getPlaceholderIP() netip.Prefix { // Using an IP from the documentation range to minimize impact in case older clients try to set a route return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } + +// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. +func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) + + enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) + for _, route := range enabledRoutes { + // If no access control groups are specified, accept all traffic. + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + } + + return routesFirewallRules +} + +func getDefaultPermit(route *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if route.Network.Addr().Is6() { + sources = []string{"::/0"} + } + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: route.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + IsDynamic: route.IsDynamic(), + } + + rules = append(rules, &rule) + + // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally + if route.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups +// and returns a list of policies that have rules with destinations matching the specified groups. +func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + group, ok := account.Groups[groupID] + if !ok { + continue + } + + for _, policy := range account.Policies { + for _, rule := range policy.Rules { + exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { + return groupID == group.ID + }) + if exist { + routePolicies = append(routePolicies, policy) + continue + } + } + } + } + + return routePolicies +} + +// generateRouteFirewallRules generates a list of firewall rules for a given route. +func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { + rulesExists := make(map[string]struct{}) + rules := make([]*RouteFirewallRule, 0) + + sourceRanges := make([]string, 0, len(groupPeers)) + for _, peer := range groupPeers { + if peer == nil { + continue + } + sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + } + + baseRule := RouteFirewallRule{ + SourceRanges: sourceRanges, + Action: string(rule.Action), + Destination: route.Network.String(), + Protocol: string(rule.Protocol), + IsDynamic: route.IsDynamic(), + } + + // generate rule for port range + if len(rule.Ports) == 0 { + rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) + } else { + rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) + + } + + // TODO: generate IPv6 rules for dynamic routes + + return rules +} + +// generateRuleIDBase generates the base rule ID for checking duplicates. +func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { + return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action +} + +// generateRulesForPeer generates rules for a given peer based on ports and port ranges. +func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + + ruleIDBase := generateRuleIDBase(rule, baseRule) + if len(rule.Ports) == 0 { + if len(rule.PortRanges) == 0 { + if _, ok := rulesExists[ruleIDBase]; !ok { + rulesExists[ruleIDBase] = struct{}{} + rules = append(rules, &baseRule) + } + } else { + for _, portRange := range rule.PortRanges { + ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) + if _, ok := rulesExists[ruleID]; !ok { + rulesExists[ruleID] = struct{}{} + pr := baseRule + pr.PortRange = portRange + rules = append(rules, &pr) + } + } + } + return rules + } + + return rules +} + +// generateRulesWithPorts generates rules when specific ports are provided. +func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + ruleIDBase := generateRuleIDBase(rule, baseRule) + + for _, port := range rule.Ports { + ruleID := ruleIDBase + port + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + pr := baseRule + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) + continue + } + + pr.Port = uint16(p) + rules = append(rules, &pr) + } + + return rules +} + +func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule { + result := make([]*proto.RouteFirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + result[i] = &proto.RouteFirewallRule{ + SourceRanges: rule.SourceRanges, + Action: getProtoAction(rule.Action), + Destination: rule.Destination, + Protocol: getProtoProtocol(rule.Protocol), + PortInfo: getProtoPortInfo(rule), + IsDynamic: rule.IsDynamic, + } + } + + return result +} + +// getProtoDirection converts the direction to proto.RuleDirection. +func getProtoDirection(direction int) proto.RuleDirection { + if direction == firewallRuleDirectionOUT { + return proto.RuleDirection_OUT + } + return proto.RuleDirection_IN +} + +// getProtoAction converts the action to proto.RuleAction. +func getProtoAction(action string) proto.RuleAction { + if action == string(PolicyTrafficActionDrop) { + return proto.RuleAction_DROP + } + return proto.RuleAction_ACCEPT +} + +// getProtoProtocol converts the protocol to proto.RuleProtocol. +func getProtoProtocol(protocol string) proto.RuleProtocol { + switch PolicyRuleProtocolType(protocol) { + case PolicyRuleProtocolALL: + return proto.RuleProtocol_ALL + case PolicyRuleProtocolTCP: + return proto.RuleProtocol_TCP + case PolicyRuleProtocolUDP: + return proto.RuleProtocol_UDP + case PolicyRuleProtocolICMP: + return proto.RuleProtocol_ICMP + default: + return proto.RuleProtocol_UNKNOWN + } +} + +// getProtoPortInfo converts the port info to proto.PortInfo. +func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { + var portInfo proto.PortInfo + if rule.Port != 0 { + portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} + } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { + portInfo.PortSelection = &proto.PortInfo_Range_{ + Range: &proto.PortInfo_Range{ + Start: uint32(portRange.Start), + End: uint32(portRange.End), + }, + } + } + return &portInfo +} diff --git a/management/server/route_test.go b/management/server/route_test.go index 4533c6b7e..09cbe53ff 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "fmt" + "net" "net/netip" "testing" @@ -44,18 +46,19 @@ var existingDomains = domain.List{"example.com"} func TestCreateRoute(t *testing.T) { type input struct { - network netip.Prefix - domains domain.List - keepRoute bool - networkType route.NetworkType - netID route.NetID - peerKey string - peerGroupIDs []string - description string - masquerade bool - metric int - enabled bool - groups []string + network netip.Prefix + domains domain.List + keepRoute bool + networkType route.NetworkType + netID route.NetID + peerKey string + peerGroupIDs []string + description string + masquerade bool + metric int + enabled bool + groups []string + accessControlGroups []string } testCases := []struct { @@ -69,100 +72,107 @@ func TestCreateRoute(t *testing.T) { { name: "Happy Path Network", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerKey: peer1ID, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerKey: peer1ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup1}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetworkType: route.IPv4Network, - NetID: "happy", - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetworkType: route.IPv4Network, + NetID: "happy", + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + AccessControlGroups: []string{routeGroup1}, }, }, { name: "Happy Path Domains", inputArgs: input{ - domains: domain.List{"domain1", "domain2"}, - keepRoute: true, - networkType: route.DomainNetwork, - netID: "happy", - peerKey: peer1ID, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + domains: domain.List{"domain1", "domain2"}, + keepRoute: true, + networkType: route.DomainNetwork, + netID: "happy", + peerKey: peer1ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup1}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.0.2.0/32"), - Domains: domain.List{"domain1", "domain2"}, - NetworkType: route.DomainNetwork, - NetID: "happy", - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - KeepRoute: true, + Network: netip.MustParsePrefix("192.0.2.0/32"), + Domains: domain.List{"domain1", "domain2"}, + NetworkType: route.DomainNetwork, + NetID: "happy", + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + KeepRoute: true, + AccessControlGroups: []string{routeGroup1}, }, }, { name: "Happy Path Peer Groups", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1, routeGroup2}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1, routeGroup2}, + accessControlGroups: []string{routeGroup1, routeGroup2}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetworkType: route.IPv4Network, - NetID: "happy", - PeerGroups: []string{routeGroupHA1, routeGroupHA2}, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1, routeGroup2}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetworkType: route.IPv4Network, + NetID: "happy", + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + AccessControlGroups: []string{routeGroup1, routeGroup2}, }, }, { name: "Both network and domains provided should fail", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - domains: domain.List{"domain1", "domain2"}, - netID: "happy", - peerKey: peer1ID, - peerGroupIDs: []string{routeGroupHA1}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + domains: domain.List{"domain1", "domain2"}, + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup2}, }, errFunc: require.Error, shouldCreate: false, @@ -170,16 +180,17 @@ func TestCreateRoute(t *testing.T) { { name: "Both peer and peer_groups Provided Should Fail", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerKey: peer1ID, - peerGroupIDs: []string{routeGroupHA1}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup2}, }, errFunc: require.Error, shouldCreate: false, @@ -423,13 +434,13 @@ func TestCreateRoute(t *testing.T) { if testCase.createInitRoute { groupAll, errInit := account.GetGroupAll() require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) } - outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) + outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) @@ -1037,15 +1048,16 @@ func TestDeleteRoute(t *testing.T) { func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { baseRoute := &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superNet", - NetworkType: route.IPv4Network, - PeerGroups: []string{routeGroupHA1, routeGroupHA2}, - Description: "ha route", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1, routeGroup2}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "ha route", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + AccessControlGroups: []string{routeGroup1}, } am, err := createRouterManager(t) @@ -1062,7 +1074,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute) + newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) @@ -1127,16 +1139,17 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { // no routes for peer in different groups // no routes when route is deleted baseRoute := &route.Route{ - ID: "testingRoute", - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superNet", - NetworkType: route.IPv4Network, - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, + ID: "testingRoute", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + AccessControlGroups: []string{routeGroup1}, } am, err := createRouterManager(t) @@ -1153,7 +1166,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute) + createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1244,7 +1257,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -1467,3 +1480,300 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return am.Store.GetAccount(context.Background(), account.Id) } + +func TestAccount_getPeersRoutesFirewall(t *testing.T) { + var ( + peerBIp = "100.65.80.39" + peerCIp = "100.65.254.139" + peerHIp = "100.65.29.55" + ) + + account := &Account{ + Peers: map[string]*nbpeer.Peer{ + "peerA": { + ID: "peerA", + IP: net.ParseIP("100.65.14.88"), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerB": { + ID: "peerB", + IP: net.ParseIP(peerBIp), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{}, + }, + "peerC": { + ID: "peerC", + IP: net.ParseIP(peerCIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerD": { + ID: "peerD", + IP: net.ParseIP("100.65.62.5"), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerE": { + ID: "peerE", + IP: net.ParseIP("100.65.32.206"), + Key: peer1Key, + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerF": { + ID: "peerF", + IP: net.ParseIP("100.65.250.202"), + Status: &nbpeer.PeerStatus{}, + }, + "peerG": { + ID: "peerG", + IP: net.ParseIP("100.65.13.186"), + Status: &nbpeer.PeerStatus{}, + }, + "peerH": { + ID: "peerH", + IP: net.ParseIP(peerHIp), + Status: &nbpeer.PeerStatus{}, + }, + }, + Groups: map[string]*nbgroup.Group{ + "routingPeer1": { + ID: "routingPeer1", + Name: "RoutingPeer1", + Peers: []string{ + "peerA", + }, + }, + "routingPeer2": { + ID: "routingPeer2", + Name: "RoutingPeer2", + Peers: []string{ + "peerD", + }, + }, + "route1": { + ID: "route1", + Name: "Route1", + Peers: []string{}, + }, + "route2": { + ID: "route2", + Name: "Route2", + Peers: []string{}, + }, + "finance": { + ID: "finance", + Name: "Finance", + Peers: []string{ + "peerF", + "peerG", + }, + }, + "dev": { + ID: "dev", + Name: "Dev", + Peers: []string{ + "peerC", + "peerH", + "peerB", + }, + }, + "contractors": { + ID: "contractors", + Name: "Contractors", + Peers: []string{}, + }, + }, + Routes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "route1", + NetworkType: route.IPv4Network, + PeerGroups: []string{"routingPeer1", "routingPeer2"}, + Description: "Route1 ha route", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"dev"}, + AccessControlGroups: []string{"route1"}, + }, + "route2": { + ID: "route2", + Network: existingNetwork, + NetID: "route2", + NetworkType: route.IPv4Network, + Peer: "peerE", + Description: "Allow", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"finance"}, + AccessControlGroups: []string{"route2"}, + }, + "route3": { + ID: "route3", + Network: netip.MustParsePrefix("192.0.2.0/32"), + Domains: domain.List{"example.com"}, + NetID: "route3", + NetworkType: route.DomainNetwork, + Peer: "peerE", + Description: "Allow all traffic to routed DNS network", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"contractors"}, + AccessControlGroups: []string{}, + }, + }, + Policies: []*Policy{ + { + ID: "RuleRoute1", + Name: "Route1", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute1", + Name: "ruleRoute1", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + Ports: []string{"80", "320"}, + Sources: []string{ + "dev", + }, + Destinations: []string{ + "route1", + }, + }, + }, + }, + { + ID: "RuleRoute2", + Name: "Route2", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute2", + Name: "ruleRoute2", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + PortRanges: []RulePortRange{ + { + Start: 80, + End: 350, + }, { + Start: 80, + End: 350, + }, + }, + Sources: []string{ + "finance", + }, + Destinations: []string{ + "route2", + }, + }, + }, + }, + }, + } + + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + t.Run("check applied policies for the route", func(t *testing.T) { + route1 := account.Routes["route1"] + policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) + assert.Len(t, policies, 1) + + route2 := account.Routes["route2"] + policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) + assert.Len(t, policies, 1) + + route3 := account.Routes["route3"] + policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) + assert.Len(t, policies, 0) + }) + + t.Run("check peer routes firewall rules", func(t *testing.T) { + routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) + assert.Len(t, routesFirewallRules, 2) + + expectedRoutesFirewallRules := []*RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerCIp), + fmt.Sprintf(AllowedIPsFormat, peerHIp), + fmt.Sprintf(AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerCIp), + fmt.Sprintf(AllowedIPsFormat, peerHIp), + fmt.Sprintf(AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 320, + }, + } + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) + assert.Len(t, routesFirewallRules, 2) + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + // peerE is a single routing peer for route 2 and route 3 + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) + assert.Len(t, routesFirewallRules, 3) + + expectedRoutesFirewallRules = []*RouteFirewallRule{ + { + SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, + Action: "accept", + Destination: existingNetwork.String(), + Protocol: "tcp", + PortRange: RulePortRange{Start: 80, End: 350}, + }, + { + SourceRanges: []string{"0.0.0.0/0"}, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + IsDynamic: true, + }, + { + SourceRanges: []string{"::/0"}, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + IsDynamic: true, + }, + } + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + // peerC is part of route1 distribution groups but should not receive the routes firewall rules + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + assert.Len(t, routesFirewallRules, 0) + }) + +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index d3840cca2..de3dfa945 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "runtime/debug" + "strconv" "strings" "sync" "time" @@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr if err != nil { return nil, err } - conns := runtime.NumCPU() - sql.SetMaxOpenConns(conns) // TODO: make it configurable + + conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS")) + if err != nil { + conns = runtime.NumCPU() + } + sql.SetMaxOpenConns(conns) + + log.Infof("Set max open db connections to %d", conns) if err := migrate(ctx, db); err != nil { return nil, fmt.Errorf("migrate: %w", err) @@ -316,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. return nil } +func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { + accountCopy := Account{ + Domain: domain, + DomainCategory: category, + IsDomainPrimaryAccount: isPrimaryDomain, + } + + fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} + result := s.db.WithContext(ctx).Model(&Account{}). + Select(fieldsToUpdate). + Where(idQueryCondition, accountID). + Updates(&accountCopy) + if result.Error != nil { + return result.Error + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "account %s", accountID) + } + + return nil +} + func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -361,32 +391,41 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P return nil } -// SaveUser saves a user to the store. -func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { - return saveRecord[User](s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}), lockStrength, user) +// SaveUsers saves the given list of users to the database. +// It updates existing users if a conflict occurs. +func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { + usersToSave := make([]User, 0, len(users)) + for _, user := range users { + user.AccountID = accountID + for id, pat := range user.PATs { + pat.ID = id + user.PATsG = append(user.PATsG, *pat) + } + usersToSave = append(usersToSave, *user) + } + return s.db.Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.OnConflict{UpdateAll: true}). + Create(&usersToSave).Error } -// SaveUsers saves a list of users to the store. -func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error { - result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&users) +// SaveUser saves the given user to the database. +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) if result.Error != nil { - return status.Errorf(status.Internal, "failed to save users to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) } return nil } -// DeleteUser deletes a user from the store. -func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, userID, accountID string) error { - return deleteRecordByID[User](s.db.WithContext(ctx).Select(clause.Associations), lockStrength, userID, accountID) -} +// SaveGroups saves the given list of groups to the database. +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + if len(groups) == 0 { + return nil + } -// DeleteUsers deletes a list of users from the store. -func (s *SqlStore) DeleteUsers(ctx context.Context, strength LockingStrength, userIDs []string, accountID string) error { - result := s.db.WithContext(ctx).Select(clause.Associations).Clauses(clause.Locking{Strength: string(strength)}). - Where("id IN ? AND account_id = ?", userIDs, accountID).Delete(&User{}) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) if result.Error != nil { - return status.Errorf(status.Internal, "failed to delete users from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) } return nil } @@ -422,7 +461,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -435,7 +474,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } if key.AccountID == "" { @@ -453,7 +492,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return token.ID, nil @@ -467,7 +506,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if token.UserID == "" { @@ -502,9 +541,18 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -// GetAccountUsers returns all users associated with the account. -func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { - return getRecords[User](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) +func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { + var users []*User + result := s.db.Find(&users, accountIDCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") + } + log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting users from store") + } + + return users, nil } func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { @@ -556,7 +604,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us @@ -619,7 +667,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if user.AccountID == "" { @@ -636,7 +684,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -654,7 +702,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -672,7 +720,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -685,7 +733,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -698,7 +746,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.NewSetupKeyNotFoundError() + return "", status.NewSetupKeyNotFoundError(result.Error) } if accountID == "" { @@ -708,22 +756,6 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) return accountID, nil } -// GetAccountOwnerID returns the owner ID of the account. -func (s *SqlStore) GetAccountOwnerID(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) { - var ownerID string - - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). - Select("created_by").Where(idQueryCondition, accountID).First(&ownerID) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return "", status.Errorf(status.NotFound, "account not found") - } - return "", status.Errorf(status.Internal, "failed to get account owner from store: %v", result.Error) - } - - return ownerID, nil -} - func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { var ipJSONStrings []string @@ -735,7 +767,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } - return nil, status.Errorf(status.Internal, "issue getting IPs from store") + return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } // Convert the JSON strings to net.IP objects @@ -763,7 +795,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock return nil, status.Errorf(status.NotFound, "no peers found for the account") } log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error) } return labels, nil @@ -776,7 +808,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting network from store") + return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } return accountNetwork.Network, nil } @@ -788,7 +820,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - return nil, status.Errorf(status.Internal, "issue getting peer from store") + return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } return &peer, nil @@ -800,26 +832,11 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - return nil, status.Errorf(status.Internal, "issue getting settings from store") + return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) } return accountSettings.Settings, nil } -// SaveAccountSettings stores the account settings in DB. -func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error { - result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). - Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings}) - if result.Error != nil { - return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error) - } - - if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "account not found") - } - - return nil -} - // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user User @@ -897,7 +914,6 @@ func getGormConfig() *gorm.Config { Logger: logger.Default.LogMode(logger.Silent), CreateBatchSize: 400, PrepareStmt: true, - TranslateError: true, } } @@ -932,19 +948,19 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data return store, nil } -// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB. -func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { +// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. +func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewPostgresqlStore(ctx, dsn, metrics) if err != nil { return nil, err } - err = store.SaveInstallationID(ctx, fileStore.InstallationID) + err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID()) if err != nil { return nil, err } - for _, account := range fileStore.GetAllAccounts(ctx) { + for _, account := range sqliteStore.GetAllAccounts(ctx) { err := store.SaveAccount(ctx, account) if err != nil { return nil, err @@ -962,7 +978,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } return &setupKey, nil } @@ -994,7 +1010,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") } - return status.Errorf(status.Internal, "issue finding group 'All'") + return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -1006,7 +1022,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group 'All'") + return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } return nil @@ -1020,7 +1036,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group not found for account") } - return status.Errorf(status.Internal, "issue finding group") + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -1032,25 +1048,29 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group") + return status.Errorf(status.Internal, "issue updating group: %s", err) } return nil } +// GetUserPeers retrieves peers for a user. +func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { + return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) +} + func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { - return status.Errorf(status.Internal, "issue adding peer to account") + return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). - Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count") + return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) } return nil } @@ -1093,19 +1113,6 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki return &accountDNSSettings.DNSSettings, nil } -// SaveDNSSettings saves the DNS settings to the store. -func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). - Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) - if result.Error != nil { - return status.Errorf(status.Internal, "failed to save dns settings to store: %v", result.Error) - } - if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "account not found") - } - return nil -} - // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { var accountID string @@ -1158,38 +1165,18 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } -// SaveGroup saves a group to the database. +// SaveGroup saves a group to the store. func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { - return saveRecord[nbgroup.Group](s.db.WithContext(ctx), lockStrength, group) -} - -// SaveGroups saves the given list of groups to the database. -func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { - return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) - } - return nil -} - -// DeleteGroup deletes a group from the database. -func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) error { - return deleteRecordByID[nbgroup.Group](s.db.WithContext(ctx), lockStrength, groupID, accountID) -} - -// DeleteGroups deletes groups from the database. -func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, groupIDs []string, accountID string) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}). - Where("id IN ? AND account_id = ?", groupIDs, accountID).Delete(&nbgroup.Group{}) - if result.Error != nil { - return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) } return nil } // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) + return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) } // GetPolicyByID retrieves a policy by its ID and account ID. @@ -1197,21 +1184,9 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) } -// SavePolicy saves a policy to the database. -func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { - return s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy).Error -} - -// DeletePolicy deletes a policy from the database. -func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID, accountID string) error { - return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID).Error -} - // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) } // GetPostureChecksByID retrieves posture checks by their ID and account ID. @@ -1219,28 +1194,9 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) } -// SavePostureChecks saves a posture checks to the database. -func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrDuplicatedKey) { - return status.Errorf(status.InvalidArgument, "name should be unique") - } - return status.Errorf(status.Internal, "failed to save posture checks to store: %v", result.Error) - } - - return nil -} - -// DeletePostureChecks deletes a posture checks from the database. -func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID, accountID string) error { - return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID).Error -} - // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { - return getRecords[route.Route](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) } // GetRouteByID retrieves a route by its ID and account ID. @@ -1248,20 +1204,9 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID) } -// SaveRoute saves a route to the database. -func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error { - return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route).Error -} - -// DeleteRoute deletes a route from the database. -func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, routeID, accountID string) error { - return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID).Error -} - // GetAccountSetupKeys retrieves setup keys for an account. func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - return getRecords[SetupKey](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) } // GetSetupKeyByID retrieves a setup key by its ID and account ID. @@ -1269,21 +1214,9 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) } -// SaveSetupKey saves a setup key to the database. -func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { - return s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&setupKey).Error -} - -// DeleteSetupKey deletes a setup key from the database. -func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, setupKeyID, accountID string) error { - return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, setupKeyID).Error -} - // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { - return getRecords[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) } // GetNameServerGroupByID retrieves a name server group by its ID and account ID. @@ -1291,70 +1224,15 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) } -// SaveNameServerGroup saves a name server group to the database. -func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error { - return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup).Error -} - -// DeleteNameServerGroup deletes a name server group from the database. -func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroupID, accountID string) error { - return deleteRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nameServerGroupID, accountID) -} - -// GetPATByID retrieves a personal access token by its ID and user ID. -func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) { - var pat PersonalAccessToken - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&pat, "id = ? AND user_id = ?", patID, userID) - if err := result.Error; err != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "PAT not found") - } - return nil, status.Errorf(status.Internal, "failed to get PAT from store") - } - - return &pat, nil -} - -// SavePAT saves a personal access token to the database. -func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error { - return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat).Error -} - -// DeletePAT deletes a personal access token from the database. -func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, patID, userID string) error { - return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&PersonalAccessToken{}, "id = ? AND user_id = ?", patID, userID).Error -} - -// GetAccountPeers retrieves peers for an account. -func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { - return getRecords[nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID) -} - -// GetUserPeers retrieves peers for a user. -func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { - return getRecords[nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) -} - -// GetAccountPeersWithExpiration retrieves a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user. -func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { - db := s.db.WithContext(ctx).Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true) - return getRecords[nbpeer.Peer](db, lockStrength, accountID) -} - -// GetPeerByID retrieves a peer by its ID and account ID. -func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, peerID string, accountID string) (*nbpeer.Peer, error) { - return getRecordByID[nbpeer.Peer](s.db.WithContext(ctx), lockStrength, peerID, accountID) -} - // getRecords retrieves records from the database based on the account ID. -func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]*T, error) { - var record []*T +func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { + var record []T result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) if err := result.Error; err != nil { - recordType := getRecordType(record) + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) } @@ -1368,7 +1246,8 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&record, accountAndIDQueryCondition, accountID, recordID) if err := result.Error; err != nil { - recordType := getRecordType(record) + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "%s not found", recordType) @@ -1377,36 +1256,3 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } - -// saveRecord saves a record to the database. -func saveRecord[T any](db *gorm.DB, lockStrength LockingStrength, record *T) error { - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(record) - if result.Error != nil { - return status.Errorf(status.Internal, "failed to save %s to store: %v", getRecordType(record), result.Error) - } - - return nil -} - -// deleteRecordByID deletes a record by its ID and account ID from the database. -func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error { - var record T - - recordType := getRecordType(record) - - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&record, accountAndIDQueryCondition, accountID, recordID) - if err := result.Error; err != nil { - return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err) - } - - if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "%s not found", recordType) - } - - return nil -} - -func getRecordType(record any) string { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - return parts[len(parts)-1] -} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 64ef36831..20e812ea7 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -7,25 +7,22 @@ import ( "net" "net/netip" "os" - "path/filepath" "runtime" "testing" "time" - nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/testutil" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" + route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/util" ) func TestSqlite_NewStore(t *testing.T) { @@ -33,7 +30,10 @@ func TestSqlite_NewStore(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") @@ -41,15 +41,23 @@ func TestSqlite_NewStore(t *testing.T) { } func TestSqlite_SaveAccount_Large(t *testing.T) { - if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" { - t.Skip("skip large test on non-linux OS due to environment restrictions") + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } + t.Run("SQLite", func(t *testing.T) { - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) runLargeTest(t, store) }) + // create store outside to have a better time counter for the test - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) t.Run("PostgreSQL", func(t *testing.T) { runLargeTest(t, store) }) @@ -201,7 +209,10 @@ func TestSqlite_SaveAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() @@ -215,7 +226,7 @@ func TestSqlite_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") @@ -273,7 +284,10 @@ func TestSqlite_DeleteAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) testUserID := "testuser" user := NewAdminUser(testUserID) @@ -295,7 +309,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 1 { @@ -326,7 +340,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { for _, policy := range account.Policies { var rules []*PolicyRule - err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") @@ -334,7 +348,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { for _, accountUser := range account.Users { var pats []*PersonalAccessToken - err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -347,7 +361,10 @@ func TestSqlite_GetAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -367,7 +384,10 @@ func TestSqlite_SavePeer(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -415,7 +435,10 @@ func TestSqlite_SavePeerStatus(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -468,7 +491,10 @@ func TestSqlite_SavePeerLocation(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -519,7 +545,10 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) existingDomain := "test.com" @@ -539,7 +568,10 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -560,7 +592,10 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -576,13 +611,18 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { } func TestMigrate(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newSqliteStore(t) + // TODO: figure out why this fails on postgres + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - err := migrate(context.Background(), store.db) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on empty db") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") @@ -618,7 +658,7 @@ func TestMigrate(t *testing.T) { }, } - err = store.db.Save(act).Error + err = store.(*SqlStore).db.Save(act).Error require.NoError(t, err, "Failed to insert Gob data") type route struct { @@ -634,16 +674,16 @@ func TestMigrate(t *testing.T) { Route: route2.Route{ID: "route1"}, } - err = store.db.Save(rt).Error + err = store.(*SqlStore).db.Save(rt).Error require.NoError(t, err, "Failed to insert Gob data") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on gob populated db") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") - err = store.db.Delete(rt).Where("id = ?", "route1").Error + err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error require.NoError(t, err, "Failed to delete Gob data") prefix = netip.MustParsePrefix("12.0.0.0/24") @@ -653,13 +693,13 @@ func TestMigrate(t *testing.T) { Peer: "peer-id", } - err = store.db.Save(nRT).Error + err = store.(*SqlStore).db.Save(nRT).Error require.NoError(t, err, "Failed to insert json nil slice data") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on json nil slice populated db") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") } @@ -668,24 +708,9 @@ func newSqliteStore(t *testing.T) *SqlStore { t.Helper() store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - -func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore { - t.Helper() - - storeDir := t.TempDir() - - err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) - require.NoError(t, err) - - fStore, err := NewFileStore(context.Background(), storeDir, nil) - require.NoError(t, err) - - store, err := NewSqliteStoreFromFileStore(context.Background(), fStore, storeDir, nil) + t.Cleanup(func() { + store.Close(context.Background()) + }) require.NoError(t, err) require.NotNil(t, store) @@ -709,64 +734,15 @@ func newAccount(store Store, id int) error { return store.SaveAccount(context.Background(), account) } -func newPostgresqlStore(t *testing.T) *SqlStore { - t.Helper() - - cleanUp, err := testutil.CreatePGDB() - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - - postgresDsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) - } - - store, err := NewPostgresqlStore(context.Background(), postgresDsn, nil) - if err != nil { - t.Fatalf("could not initialize postgresql store: %s", err) - } - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - -func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore { - t.Helper() - - storeDir := t.TempDir() - err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) - require.NoError(t, err) - - fStore, err := NewFileStore(context.Background(), storeDir, nil) - require.NoError(t, err) - - cleanUp, err := testutil.CreatePGDB() - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - - postgresDsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) - } - - store, err := NewPostgresqlStoreFromFileStore(context.Background(), fStore, postgresDsn, nil) - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - func TestPostgresql_NewStore(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") @@ -774,11 +750,14 @@ func TestPostgresql_NewStore(t *testing.T) { } func TestPostgresql_SaveAccount(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() @@ -792,7 +771,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") @@ -846,11 +825,14 @@ func TestPostgresql_SaveAccount(t *testing.T) { } func TestPostgresql_DeleteAccount(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) testUserID := "testuser" user := NewAdminUser(testUserID) @@ -872,7 +854,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 1 { @@ -903,7 +885,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { for _, policy := range account.Policies { var rules []*PolicyRule - err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") @@ -911,7 +893,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { for _, accountUser := range account.Users { var pats []*PersonalAccessToken - err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -920,11 +902,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } func TestPostgresql_SavePeerStatus(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -959,11 +944,14 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { } func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) existingDomain := "test.com" @@ -976,11 +964,14 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { } func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -991,11 +982,14 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { } func TestPostgresql_GetUserByTokenID(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -1005,16 +999,16 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { } func TestSqlite_GetTakenIPs(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + defer cleanup() + if err != nil { + t.Fatal(err) } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) - existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1050,16 +1044,16 @@ func TestSqlite_GetTakenIPs(t *testing.T) { } func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + if err != nil { + return } - - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + t.Cleanup(cleanup) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -1092,16 +1086,16 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { } func TestSqlite_GetAccountNetwork(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) - existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID) @@ -1115,15 +1109,16 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { } func TestSqlite_GetSetupKeyBySecret(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") @@ -1134,15 +1129,16 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { } func TestSqlite_incrementSetupKeyUsage(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") @@ -1163,3 +1159,95 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } + +func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + + group := &nbgroup.Group{ + ID: "group-id", + AccountID: "account-id", + Name: "group-name", + Issued: "api", + Peers: nil, + } + err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error { + err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + t.Fatal("failed to save group") + return err + } + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + if err != nil { + t.Fatal("failed to get group") + return err + } + t.Logf("group: %v", group) + return nil + }) + assert.NoError(t, err) +} + +func TestSqlite_GetAccoundUsers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + users, err := store.GetAccountUsers(context.Background(), accountID) + require.NoError(t, err) + require.Len(t, users, len(account.Users)) +} + +func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + t.Run("Should update attributes with public domain", func(t *testing.T) { + require.NoError(t, err) + domain := "example.com" + category := "public" + IsDomainPrimaryAccount := false + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should update attributes with private domain", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should fail when account does not exist", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) + require.Error(t, err) + }) + +} diff --git a/management/server/status/error.go b/management/server/status/error.go index d7fde35b9..29d185216 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error { } // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key -func NewSetupKeyNotFoundError() error { - return Errorf(NotFound, "setup key not found") +func NewSetupKeyNotFoundError(err error) error { + return Errorf(NotFound, "setup key not found: %s", err) +} + +func NewGetAccountFromStoreError(err error) error { + return Errorf(Internal, "issue getting account from store: %s", err) } // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store diff --git a/management/server/store.go b/management/server/store.go index fc5796d3c..131fd8aaa 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -9,13 +9,16 @@ import ( "os" "path" "path/filepath" + "runtime" "strings" "time" - "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" + "gorm.io/driver/sqlite" "gorm.io/gorm" + "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/telemetry" @@ -51,60 +54,41 @@ type Store interface { GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) - GetAccountOwnerID(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) + GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error - - GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) - SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error - - GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) - SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error + UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) + GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) + SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error - SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error - DeleteUser(ctx context.Context, lockStrength LockingStrength, userID, accountID string) error - DeleteUsers(ctx context.Context, strength LockingStrength, userIDs []string, accountID string) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) - SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error - DeletePAT(ctx context.Context, strength LockingStrength, patID string, userID string) error - GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error - DeleteGroup(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) error - DeleteGroups(ctx context.Context, strength LockingStrength, groupIDs []string, accountID string) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) - SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error - DeletePolicy(ctx context.Context, lockStrength LockingStrength, postureCheckID, accountID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error - DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID, accountID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) - GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) - GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) - GetPeerByID(ctx context.Context, lockStrength LockingStrength, peerID string, accountID string) (*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error @@ -113,21 +97,15 @@ type Store interface { IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) - SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error - DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, setupKeyID, accountID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) - SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error - DeleteRoute(ctx context.Context, lockStrength LockingStrength, routeID, accountID string) error GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) - SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error - DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroupID, accountID string) error GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error + IncrementNetworkSerial(ctx context.Context, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetInstallationID() string @@ -266,23 +244,40 @@ func getMigrations(ctx context.Context) []migrationFunc { } } -// NewTestStoreFromJson is only used in tests -func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), error) { - fstore, err := NewFileStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err - } - - // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE +// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env. +// Optionally it can load a SQL file to the database. If the filename is empty it will return an empty database +func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) { kind := getStoreEngineFromEnv() if kind == "" { kind = SqliteStoreEngine } - var ( - store Store - cleanUp func() - ) + storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) + if runtime.GOOS == "windows" { + // Vo avoid `The process cannot access the file because it is being used by another process` on Windows + storeStr = storeSqliteFileName + } + + file := filepath.Join(dataDir, storeStr) + db, err := gorm.Open(sqlite.Open(file), getGormConfig()) + if err != nil { + return nil, nil, err + } + + if filename != "" { + err = loadSQL(db, filename) + if err != nil { + return nil, nil, fmt.Errorf("failed to load SQL file: %v", err) + } + } + + store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to create test store: %v", err) + } + cleanUp := func() { + store.Close(ctx) + } if kind == PostgresStoreEngine { cleanUp, err = testutil.CreatePGDB() @@ -295,21 +290,36 @@ func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), e return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - store, err = NewPostgresqlStoreFromFileStore(ctx, fstore, dsn, nil) + store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) if err != nil { return nil, nil, err } - } else { - store, err = NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil) - if err != nil { - return nil, nil, err - } - cleanUp = func() { store.Close(ctx) } } return store, cleanUp, nil } +func loadSQL(db *gorm.DB, filepath string) error { + sqlContent, err := os.ReadFile(filepath) + if err != nil { + return err + } + + queries := strings.Split(string(sqlContent), ";") + + for _, query := range queries { + query = strings.TrimSpace(query) + if query != "" { + err := db.Exec(query).Error + if err != nil { + return err + } + } + } + + return nil +} + // MigrateFileStoreToSqlite migrates the file store to the SQLite store. func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error { fileStorePath := path.Join(dataDir, storeFileName) diff --git a/management/server/store_test.go b/management/server/store_test.go index 40c36c9e0..fc821670d 100644 --- a/management/server/store_test.go +++ b/management/server/store_test.go @@ -14,12 +14,6 @@ type benchCase struct { size int } -var newFs = func(b *testing.B) Store { - b.Helper() - store, _ := NewFileStore(context.Background(), b.TempDir(), nil) - return store -} - var newSqlite = func(b *testing.B) Store { b.Helper() store, _ := NewSqliteStore(context.Background(), b.TempDir(), nil) @@ -28,13 +22,9 @@ var newSqlite = func(b *testing.B) Store { func BenchmarkTest_StoreWrite(b *testing.B) { cases := []benchCase{ - {name: "FileStore_Write", storeFn: newFs, size: 100}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 100}, - {name: "FileStore_Write", storeFn: newFs, size: 500}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 500}, - {name: "FileStore_Write", storeFn: newFs, size: 1000}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 1000}, - {name: "FileStore_Write", storeFn: newFs, size: 2000}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 2000}, } @@ -61,11 +51,8 @@ func BenchmarkTest_StoreWrite(b *testing.B) { func BenchmarkTest_StoreRead(b *testing.B) { cases := []benchCase{ - {name: "FileStore_Read", storeFn: newFs, size: 100}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 100}, - {name: "FileStore_Read", storeFn: newFs, size: 500}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 500}, - {name: "FileStore_Read", storeFn: newFs, size: 1000}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 1000}, } @@ -89,3 +76,11 @@ func BenchmarkTest_StoreRead(b *testing.B) { }) } } + +func newStore(t *testing.T) Store { + t.Helper() + + store := newSqliteStore(t) + + return store +} diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json deleted file mode 100644 index 7f96e57a8..000000000 --- a/management/server/testdata/extended-store.json +++ /dev/null @@ -1,120 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "CreatedBy": "", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": ["cfefqs706sqkneg59g2g"], - "UsageLimit": 0, - "Ephemeral": false - }, - "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": { - "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", - "Name": "Faulty key with non existing group", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": ["abcd"], - "UsageLimit": 0, - "Ephemeral": false - } - }, - "Network": { - "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": "", - "Serial": 0 - }, - "Peers": {}, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "admin", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": ["cfefqs706sqkneg59g3g"], - "PATs": {}, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "user", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": null, - "PATs": { - "9dj38s35-63fb-11ec-90d6-0242ac120003": { - "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", - "UserID": "", - "Name": "", - "HashedToken": "SoMeHaShEdToKeN", - "ExpirationDate": "2023-02-27T00:00:00Z", - "CreatedBy": "user", - "CreatedAt": "2023-01-01T00:00:00Z", - "LastUsed": "2023-02-01T00:00:00Z" - } - }, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - } - }, - "Groups": { - "cfefqs706sqkneg59g4g": { - "ID": "cfefqs706sqkneg59g4g", - "Name": "All", - "Peers": [] - }, - "cfefqs706sqkneg59g3g": { - "ID": "cfefqs706sqkneg59g3g", - "Name": "AwesomeGroup1", - "Peers": [] - }, - "cfefqs706sqkneg59g2g": { - "ID": "cfefqs706sqkneg59g2g", - "Name": "AwesomeGroup2", - "Peers": [] - } - }, - "Rules": null, - "Policies": [], - "Routes": null, - "NameServerGroups": null, - "DNSSettings": null, - "Settings": { - "PeerLoginExpirationEnabled": false, - "PeerLoginExpiration": 86400000000000, - "GroupsPropagationEnabled": false, - "JWTGroupsEnabled": false, - "JWTGroupsClaimName": "" - } - } - }, - "InstallationID": "" -} diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql new file mode 100644 index 000000000..b522741e7 --- /dev/null +++ b/management/server/testdata/extended-store.sql @@ -0,0 +1,37 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cfefqs706sqkneg59g2g"]',0,0); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["abcd"]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store.json b/management/server/testdata/store.json deleted file mode 100644 index 1fa4e3a9a..000000000 --- a/management/server/testdata/store.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "CreatedBy": "", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Id": "", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": null, - "UsageLimit": 0, - "Ephemeral": false - } - }, - "Network": { - "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": "", - "Serial": 0 - }, - "Peers": {}, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "admin", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": null, - "PATs": {}, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "user", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": null, - "PATs": { - "9dj38s35-63fb-11ec-90d6-0242ac120003": { - "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", - "UserID": "", - "Name": "", - "HashedToken": "SoMeHaShEdToKeN", - "ExpirationDate": "2023-02-27T00:00:00Z", - "CreatedBy": "user", - "CreatedAt": "2023-01-01T00:00:00Z", - "LastUsed": "2023-02-01T00:00:00Z" - } - }, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - } - }, - "Groups": null, - "Rules": null, - "Policies": [], - "Routes": null, - "NameServerGroups": null, - "DNSSettings": null, - "Settings": { - "PeerLoginExpirationEnabled": false, - "PeerLoginExpiration": 86400000000000, - "GroupsPropagationEnabled": false, - "JWTGroupsEnabled": false, - "JWTGroupsClaimName": "" - } - } - }, - "InstallationID": "" -} \ No newline at end of file diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql new file mode 100644 index 000000000..32a59128b --- /dev/null +++ b/management/server/testdata/store.sql @@ -0,0 +1,33 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); +INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_policy_migrate.json b/management/server/testdata/store_policy_migrate.json deleted file mode 100644 index 1b046e632..000000000 --- a/management/server/testdata/store_policy_migrate.json +++ /dev/null @@ -1,116 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": { - "cfefqs706sqkneg59g4g": { - "ID": "cfefqs706sqkneg59g4g", - "Key": "MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=", - "SetupKey": "", - "IP": "100.103.179.238", - "Meta": { - "Hostname": "Ubuntu-2204-jammy-amd64-base", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "crocodile", - "DNSLabel": "crocodile", - "Status": { - "LastSeen": "2023-02-13T12:37:12.635454796Z", - "Connected": true - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K", - "SSHEnabled": false - }, - "cfeg6sf06sqkneg59g50": { - "ID": "cfeg6sf06sqkneg59g50", - "Key": "zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=", - "SetupKey": "", - "IP": "100.103.26.180", - "Meta": { - "Hostname": "borg", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "dingo", - "DNSLabel": "dingo", - "Status": { - "LastSeen": "2023-02-21T09:37:42.565899199Z", - "Connected": false - }, - "UserID": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAILHW", - "SSHEnabled": true - } - }, - "Groups": { - "cfefqs706sqkneg59g3g": { - "ID": "cfefqs706sqkneg59g3g", - "Name": "All", - "Peers": [ - "cfefqs706sqkneg59g4g", - "cfeg6sf06sqkneg59g50" - ] - } - }, - "Rules": { - "cfefqs706sqkneg59g40": { - "ID": "cfefqs706sqkneg59g40", - "Name": "Default", - "Description": "This is a default rule that allows connections between all the resources", - "Disabled": false, - "Source": [ - "cfefqs706sqkneg59g3g" - ], - "Destination": [ - "cfefqs706sqkneg59g3g" - ], - "Flow": 0 - } - }, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql new file mode 100644 index 000000000..a9360e9d6 --- /dev/null +++ b/management/server/testdata/store_policy_migrate.sql @@ -0,0 +1,35 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:04:23.538411+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO peers VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=','','"100.103.179.238"','Ubuntu-2204-jammy-amd64-base','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'crocodile','crocodile','2023-02-13 12:37:12.635454796+00:00',1,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K',0,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-34653001fc3b','zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=','','"100.103.26.180"','borg','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'dingo','dingo','2023-02-21 09:37:42.565899199+00:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAILHW',1,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_with_expired_peers.json b/management/server/testdata/store_with_expired_peers.json deleted file mode 100644 index 44c225682..000000000 --- a/management/server/testdata/store_with_expired_peers.json +++ /dev/null @@ -1,130 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "Settings": { - "PeerLoginExpirationEnabled": true, - "PeerLoginExpiration": 3600000000000 - }, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": { - "cfvprsrlo1hqoo49ohog": { - "ID": "cfvprsrlo1hqoo49ohog", - "Key": "5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=", - "SetupKey": "72546A29-6BC8-4311-BCFC-9CDBF33F1A48", - "IP": "100.64.114.31", - "Meta": { - "Hostname": "f2a34f6a4731", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "11", - "Platform": "unknown", - "OS": "Debian GNU/Linux", - "WtVersion": "0.12.0", - "UIVersion": "" - }, - "Name": "f2a34f6a4731", - "DNSLabel": "f2a34f6a4731", - "Status": { - "LastSeen": "2023-03-02T09:21:02.189035775+01:00", - "Connected": false, - "LoginExpired": false - }, - "UserID": "", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk", - "SSHEnabled": false, - "LoginExpirationEnabled": true, - "LastLogin": "2023-03-01T19:48:19.817799698+01:00" - }, - "cg05lnblo1hkg2j514p0": { - "ID": "cg05lnblo1hkg2j514p0", - "Key": "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=", - "SetupKey": "", - "IP": "100.64.39.54", - "Meta": { - "Hostname": "expiredhost", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "expiredhost", - "DNSLabel": "expiredhost", - "Status": { - "LastSeen": "2023-03-02T09:19:57.276717255+01:00", - "Connected": false, - "LoginExpired": true - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK", - "SSHEnabled": false, - "LoginExpirationEnabled": true, - "LastLogin": "2023-03-02T09:14:21.791679181+01:00" - }, - "cg3161rlo1hs9cq94gdg": { - "ID": "cg3161rlo1hs9cq94gdg", - "Key": "mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=", - "SetupKey": "", - "IP": "100.64.117.96", - "Meta": { - "Hostname": "testhost", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "testhost", - "DNSLabel": "testhost", - "Status": { - "LastSeen": "2023-03-06T18:21:27.252010027+01:00", - "Connected": false, - "LoginExpired": false - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM", - "SSHEnabled": false, - "LoginExpirationEnabled": false, - "LastLogin": "2023-03-07T09:02:47.442857106+01:00" - } - }, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} \ No newline at end of file diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql new file mode 100644 index 000000000..100a6470f --- /dev/null +++ b/management/server/testdata/store_with_expired_peers.sql @@ -0,0 +1,35 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/storev1.json b/management/server/testdata/storev1.json deleted file mode 100644 index 674b2b87a..000000000 --- a/management/server/testdata/storev1.json +++ /dev/null @@ -1,154 +0,0 @@ -{ - "Accounts": { - "auth0|61bf82ddeab084006aa1bccd": { - "Id": "auth0|61bf82ddeab084006aa1bccd", - "SetupKeys": { - "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD": { - "Id": "831727121", - "Key": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD", - "Name": "One-off key", - "Type": "one-off", - "CreatedAt": "2021-12-24T16:09:45.926075752+01:00", - "ExpiresAt": "2022-01-23T16:09:45.926075752+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:12:45.763424077+01:00" - }, - "EB51E9EB-A11F-4F6E-8E49-C982891B405A": { - "Id": "1769568301", - "Key": "EB51E9EB-A11F-4F6E-8E49-C982891B405A", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-12-24T16:09:45.926073628+01:00", - "ExpiresAt": "2022-01-23T16:09:45.926073628+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:13:06.236748538+01:00" - } - }, - "Network": { - "Id": "a443c07a-5765-4a78-97fc-390d9c1d0e49", - "Net": { - "IP": "100.64.0.0", - "Mask": "/8AAAA==" - }, - "Dns": "" - }, - "Peers": { - "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=": { - "Key": "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=", - "SetupKey": "EB51E9EB-A11F-4F6E-8E49-C982891B405A", - "IP": "100.64.0.2", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:13:11.244342541+01:00", - "Connected": false - } - }, - "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=": { - "Key": "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=", - "SetupKey": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD", - "IP": "100.64.0.1", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:12:49.089339333+01:00", - "Connected": false - } - } - } - }, - "google-oauth2|103201118415301331038": { - "Id": "google-oauth2|103201118415301331038", - "SetupKeys": { - "5AFB60DB-61F2-4251-8E11-494847EE88E9": { - "Id": "2485964613", - "Key": "5AFB60DB-61F2-4251-8E11-494847EE88E9", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-12-24T16:10:02.238476+01:00", - "ExpiresAt": "2022-01-23T16:10:02.238476+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:12:05.994307717+01:00" - }, - "A72E4DC2-00DE-4542-8A24-62945438104E": { - "Id": "3504804807", - "Key": "A72E4DC2-00DE-4542-8A24-62945438104E", - "Name": "One-off key", - "Type": "one-off", - "CreatedAt": "2021-12-24T16:10:02.238478209+01:00", - "ExpiresAt": "2022-01-23T16:10:02.238478209+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:11:27.015741738+01:00" - } - }, - "Network": { - "Id": "b6d0b152-364e-40c1-a8a1-fa7bcac2267f", - "Net": { - "IP": "100.64.0.0", - "Mask": "/8AAAA==" - }, - "Dns": "" - }, - "Peers": { - "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=": { - "Key": "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=", - "SetupKey": "5AFB60DB-61F2-4251-8E11-494847EE88E9", - "IP": "100.64.0.2", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:12:05.994305438+01:00", - "Connected": false - } - }, - "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=": { - "Key": "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=", - "SetupKey": "A72E4DC2-00DE-4542-8A24-62945438104E", - "IP": "100.64.0.1", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:11:27.015739803+01:00", - "Connected": false - } - } - } - } - } -} \ No newline at end of file diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql new file mode 100644 index 000000000..69194d623 --- /dev/null +++ b/management/server/testdata/storev1.sql @@ -0,0 +1,39 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('auth0|61bf82ddeab084006aa1bccd','','2024-10-02 17:00:54.181873+02:00','','',0,'a443c07a-5765-4a78-97fc-390d9c1d0e49','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO accounts VALUES('google-oauth2|103201118415301331038','','2024-10-02 17:00:54.225803+02:00','','',0,'b6d0b152-364e-40c1-a8a1-fa7bcac2267f','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','One-off key','one-off','2021-12-24 16:09:45.926075752+01:00','2022-01-23 16:09:45.926075752+01:00','2021-12-24 16:09:45.926075752+01:00',0,1,'2021-12-24 16:12:45.763424077+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0); +INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO installations VALUES(1,''); + diff --git a/management/server/user.go b/management/server/user.go index 4d70c5210..71608ef20 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -8,22 +8,22 @@ import ( "time" "github.com/google/uuid" - nbgroup "github.com/netbirdio/netbird/management/server/group" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" @@ -42,6 +42,8 @@ func StrRoleToUserRole(strRole string) UserRole { return UserRoleAdmin case "user": return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin default: return UserRoleUnknown } @@ -214,12 +216,19 @@ func NewOwnerUser(id string) *User { // createServiceUser creates a new service user under the given account. func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { - executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { - return nil, err + return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } - if !executingUser.HasAdminPower() || executingUser.AccountID != accountID { + executingUser := account.Users[initiatorUserID] + if executingUser == nil { + return nil, status.Errorf(status.NotFound, "user not found") + } + if !executingUser.HasAdminPower() { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can create service users") } @@ -230,9 +239,10 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI newUserID := uuid.New().String() newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI) log.WithContext(ctx).Debugf("New User: %v", newUser) - newUser.AccountID = accountID + account.Users[newUserID] = newUser - if err = am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser); err != nil { + err = am.Store.SaveAccount(ctx, account) + if err != nil { return nil, err } @@ -260,8 +270,11 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user return am.inviteNewUser(ctx, accountID, userID, user) } -// inviteNewUser Invites a User to a given account and creates reference in datastore +// inviteNewUser Invites a USer to a given account and creates reference in datastore func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + if am.idpManager == nil { return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } @@ -282,24 +295,23 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u default: } - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) + } + + initiatorUser, err := account.FindUser(userID) if err != nil { return nil, status.Errorf(status.NotFound, "initiator user with ID %s doesn't exist", userID) } inviterID := userID if initiatorUser.IsServiceUser { - ownerID, err := am.Store.GetAccountOwnerID(ctx, LockingStrengthShare, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account owner: %v", err) - return nil, err - } - - inviterID = ownerID + inviterID = account.CreatedBy } // inviterUser is the one who is inviting the new user - inviterUser, err := am.lookupUserInCache(ctx, inviterID, accountID) + inviterUser, err := am.lookupUserInCache(ctx, inviterID, account) if err != nil || inviterUser == nil { return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID) } @@ -330,29 +342,27 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u newUser := &User{ Id: idpUser.ID, - AccountID: accountID, Role: invitedRole, AutoGroups: invite.AutoGroups, Issued: invite.Issued, IntegrationReference: invite.IntegrationReference, CreatedAt: time.Now().UTC(), } - if err = am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser); err != nil { + account.Users[idpUser.ID] = newUser + + err = am.Store.SaveAccount(ctx, account) + if err != nil { return nil, err } - _, err = am.refreshCache(ctx, accountID) + _, err = am.refreshCache(ctx, account.Id) if err != nil { return nil, err } am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil) - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - return newUser.ToUserInfo(idpUser, settings) + return newUser.ToUserInfo(idpUser, account.Settings) } func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { @@ -392,7 +402,20 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) { - return am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return nil, err + } + + users := make([]*User, 0, len(account.Users)) + for _, item := range account.Users { + users = append(users, item) + } + + return users, nil } func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) { @@ -483,12 +506,20 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + if am.idpManager == nil { return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return status.Errorf(status.NotFound, "account %s doesn't exist", accountID) + } + // check if the user is already registered with this ID - user, err := am.lookupUserInCache(ctx, targetUserID, accountID) + user, err := am.lookupUserInCache(ctx, targetUserID, account) if err != nil { return err } @@ -515,6 +546,9 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin // CreatePAT creates a new PAT for the given user func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + if tokenName == "" { return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") } @@ -523,28 +557,35 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") } - executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) - if err != nil { - return nil, err + targetUser, ok := account.Users[targetUserID] + if !ok { + return nil, status.Errorf(status.NotFound, "user not found") } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) || - executingUser.AccountID != accountID { + executingUser, ok := account.Users[initiatorUserID] + if !ok { + return nil, status.Errorf(status.NotFound, "user not found") + } + + if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") } - pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id, executingUser.Id) + pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id) if err != nil { return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) } - if err = am.Store.SavePAT(ctx, LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil { - return nil, fmt.Errorf("failed to save PAT: %w", err) + targetUser.PATs[pat.ID] = &pat.PersonalAccessToken + + err = am.Store.SaveAccount(ctx, account) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to save account: %v", err) } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} @@ -555,33 +596,51 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string // DeletePAT deletes a specific PAT from a user func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { - executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { - return err + return status.Errorf(status.NotFound, "account not found: %s", err) } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) - if err != nil { - return err + targetUser, ok := account.Users[targetUserID] + if !ok { + return status.Errorf(status.NotFound, "user not found") } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) || - executingUser.AccountID != accountID { + executingUser, ok := account.Users[initiatorUserID] + if !ok { + return status.Errorf(status.NotFound, "user not found") + } + + if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") } - pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, targetUserID) - if err != nil { - return err + pat := targetUser.PATs[tokenID] + if pat == nil { + return status.Errorf(status.NotFound, "PAT not found") } - if err = am.Store.DeletePAT(ctx, LockingStrengthUpdate, tokenID, targetUserID); err != nil { - return fmt.Errorf("failed to delete PAT: %w", err) + err = am.Store.DeleteTokenID2UserIDIndex(pat.ID) + if err != nil { + return status.Errorf(status.Internal, "Failed to delete token id index: %s", err) + } + err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken) + if err != nil { + return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err) } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) + delete(targetUser.PATs, tokenID) + + err = am.Store.SaveAccount(ctx, account) + if err != nil { + return status.Errorf(status.Internal, "Failed to save account: %s", err) + } return nil } @@ -592,11 +651,22 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, err } + targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + if err != nil { + return nil, err + } + if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - return am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, targetUserID) + for _, pat := range targetUser.PATsG { + if pat.ID == tokenID { + return pat.Copy(), nil + } + } + + return nil, status.Errorf(status.NotFound, "PAT not found") } // GetAllPATs returns all PATs for a user @@ -606,15 +676,15 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, err } - if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") - } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) if err != nil { return nil, err } + if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + } + pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG)) for _, pat := range targetUser.PATsG { pats = append(pats, pat.Copy()) @@ -635,6 +705,9 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists) if err != nil { return nil, err @@ -655,12 +728,17 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, nil //nolint:nilnil } - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } - if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() || initiatorUser.AccountID != accountID { + initiatorUser, err := account.FindUser(initiatorUserID) + if err != nil { + return nil, err + } + + if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations") } @@ -668,21 +746,15 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, var ( expiredPeers []*nbpeer.Peer eventsToStore []func() - usersToSave []*User ) - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - for _, update := range updates { if update == nil { return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - oldUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, update.Id) - if err != nil { + oldUser := account.Users[update.Id] + if oldUser == nil { if !addIfNotExists { return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) } @@ -690,7 +762,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, oldUser = update } - if err := am.validateUserUpdate(ctx, accountID, initiatorUser, oldUser, update); err != nil { + if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil { return nil, err } @@ -703,40 +775,29 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, newUser.Issued = update.Issued newUser.IntegrationReference = update.IntegrationReference - // handle owner role transfer - transferredOwnerRole := initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner - if transferredOwnerRole { - newInitiatorUser := initiatorUser.Copy() - newInitiatorUser.Role = UserRoleAdmin - - usersToSave = append(usersToSave, newInitiatorUser) - } - usersToSave = append(usersToSave, newUser) + transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update) + account.Users[newUser.Id] = newUser if !oldUser.IsBlocked() && update.IsBlocked() { // expire peers that belong to the user who's getting blocked - blockedPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, update.Id, accountID) + blockedPeers, err := account.FindUserPeers(update.Id) if err != nil { return nil, err } expiredPeers = append(expiredPeers, blockedPeers...) } - if update.AutoGroups != nil && settings.GroupsPropagationEnabled { - //removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) + if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { + removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) // need force update all auto groups in any case they will not be duplicated - - //TODO: wraps this in a transaction - - //account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) - //account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) - + account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) + account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) } - events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, accountID, transferredOwnerRole) + events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) eventsToStore = append(eventsToStore, events...) - updatedUserInfo, err := getUserInfo(ctx, am, newUser, accountID) + updatedUserInfo, err := getUserInfo(ctx, am, newUser, account) if err != nil { return nil, err } @@ -744,63 +805,40 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if len(expiredPeers) > 0 { - if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) - } - - //TODO: update groups with new members - - if err = transaction.SaveUsers(ctx, LockingStrengthUpdate, usersToSave); err != nil { - return fmt.Errorf("failed to save users: %w", err) - } - - return nil - }) - if err != nil { + account.Network.IncSerial() + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } + if account.Settings.GroupsPropagationEnabled { + am.updateAccountPeers(ctx, account) + } + for _, storeEvent := range eventsToStore { storeEvent() } - if settings.GroupsPropagationEnabled { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) - } - return updatedUsers, nil } -// propagateAutoGroupChangesForUser updates the user's auto-groups. -// If group propagation is enabled, it adds or removes groups from -// the peers owned by the user based on changes in their group assignments. -func (am *DefaultAccountManager) propagateAutoGroupChangesForUser(ctx context.Context, oldUser, updatedUser *User) []*nbgroup.Group { - return nil -} - // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, accountID string, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { if newUser.IsBlocked() { eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil) }) } else { eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil) }) } } @@ -808,11 +846,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in switch { case transferredOwnerRole: eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil) }) case oldUser.Role != newUser.Role: eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) }) } @@ -820,35 +858,23 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) for _, g := range removedGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) - if err != nil { - log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID) - } else { + group := account.GetGroup(g) + if group != nil { eventsToStore = append(eventsToStore, func() { - meta := map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": newUser.IsServiceUser, - "user_name": newUser.ServiceUserName, - } - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) }) + + } else { + log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) } } - for _, g := range addedGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) - if err != nil { - log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID) - } else { + group := account.GetGroup(g) + if group != nil { eventsToStore = append(eventsToStore, func() { - meta := map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": newUser.IsServiceUser, - "user_name": newUser.ServiceUserName, - } - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) }) } } @@ -857,27 +883,32 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in return eventsToStore } +func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool { + if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner { + newInitiatorUser := initiatorUser.Copy() + newInitiatorUser.Role = UserRoleAdmin + account.Users[initiatorUser.Id] = newInitiatorUser + return true + } + return false +} + // getUserInfo retrieves the UserInfo for a given User and Account. // If the AccountManager has a non-nil idpManager and the User is not a service user, // it will attempt to look up the UserData from the cache. -func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, accountID string) (*UserInfo, error) { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - +func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) { if !isNil(am.idpManager) && !user.IsServiceUser { - userData, err := am.lookupUserInCache(ctx, user.Id, accountID) + userData, err := am.lookupUserInCache(ctx, user.Id, account) if err != nil { return nil, err } - return user.ToUserInfo(userData, settings) + return user.ToUserInfo(userData, account.Settings) } - return user.ToUserInfo(nil, settings) + return user.ToUserInfo(nil, account.Settings) } // validateUserUpdate validates the update operation for a user. -func (am *DefaultAccountManager) validateUserUpdate(ctx context.Context, accountID string, initiatorUser, oldUser, update *User) error { +func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error { if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } @@ -898,12 +929,11 @@ func (am *DefaultAccountManager) validateUserUpdate(ctx context.Context, account } for _, newGroupID := range update.AutoGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroupID, accountID) - if err != nil { + group, ok := account.Groups[newGroupID] + if !ok { return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", newGroupID, update.Id) } - if group.Name == "All" { return status.Errorf(status.InvalidArgument, "can't add All group to the user") } @@ -954,26 +984,21 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } - if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "no permission to get users") - } - - accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + user, err := account.FindUser(userID) if err != nil { return nil, err } queriedUsers := make([]*idp.UserData, 0) if !isNil(am.idpManager) { - users := make(map[string]userLoggedInOnce, len(accountUsers)) - + users := make(map[string]userLoggedInOnce, len(account.Users)) usersFromIntegration := make([]*idp.UserData, 0) - for _, user := range accountUsers { + for _, user := range account.Users { if user.Issued == UserIssuedIntegration { key := user.IntegrationReference.CacheKey(accountID, user.Id) info, err := am.externalCacheManager.Get(am.ctx, key) @@ -998,21 +1023,16 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun queriedUsers = append(queriedUsers, usersFromIntegration...) } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - userInfos := make([]*UserInfo, 0) // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo if len(queriedUsers) == 0 { - for _, accountUser := range accountUsers { + for _, accountUser := range account.Users { if !(user.HasAdminPower() || user.IsServiceUser || user.Id == accountUser.Id) { // if user is not an admin then show only current user and do not show other users continue } - info, err := accountUser.ToUserInfo(nil, settings) + info, err := accountUser.ToUserInfo(nil, account.Settings) if err != nil { return nil, err } @@ -1021,7 +1041,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun return userInfos, nil } - for _, localUser := range accountUsers { + for _, localUser := range account.Users { if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != localUser.Id { // if user is not an admin then show only current user and do not show other users continue @@ -1029,7 +1049,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun var info *UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { - info, err = localUser.ToUserInfo(queriedUser, settings) + info, err = localUser.ToUserInfo(queriedUser, account.Settings) if err != nil { return nil, err } @@ -1042,7 +1062,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun dashboardViewPermissions := "full" if !localUser.HasAdminPower() { dashboardViewPermissions = "limited" - if settings.RegularUsersViewBlocked { + if account.Settings.RegularUsersViewBlocked { dashboardViewPermissions = "blocked" } } @@ -1066,7 +1086,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { if peer.Status.LoginExpired { @@ -1074,13 +1094,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou } peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) - - if err := am.Store.SavePeerStatus(accountID, peer.ID, *peer.Status); err != nil { + account.UpdatePeer(peer) + if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { return err } am.StoreEvent( ctx, - peer.UserID, peer.ID, accountID, + peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), ) } @@ -1088,10 +1108,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } am.updateAccountPeers(ctx, account) } return nil @@ -1241,6 +1257,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil } +// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. +func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { + + if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { + return + } + + userPeerIDMap := make(map[string]struct{}, len(peers)) + for _, peer := range peers { + userPeerIDMap[peer.ID] = struct{}{} + } + + for _, gid := range groupsToAdd { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + addUserPeersToGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + for _, gid := range groupsToRemove { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + removeUserPeersFromGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + return groupsToUpdate, nil +} + +// addUserPeersToGroup adds the user's peers to the group. +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + groupPeers := make(map[string]struct{}, len(group.Peers)) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = make([]string, 0, len(groupPeers)) + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } +} + +// removeUserPeersFromGroup removes user's peers from the group. +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + // skip removing peers from group All + if group.Name == "All" { + return + } + + updatedPeers := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if _, found := userPeerIDs[pid]; !found { + updatedPeers = append(updatedPeers, pid) + } + } + + group.Peers = updatedPeers +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index e394ef840..1a5704551 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -59,8 +59,10 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { assert.Equal(t, pat.CreatedBy, mockUserID) - fileStore := am.Store.(*FileStore) - tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken] + tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken) + if err != nil { + t.Fatalf("Error when getting token ID by hashed token: %s", err) + } if tokenID == "" { t.Fatal("GetTokenIDByHashedToken failed after adding PAT") @@ -68,11 +70,12 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { assert.Equal(t, pat.ID, tokenID) - userID := fileStore.TokenID2UserID[tokenID] - if userID == "" { - t.Fatal("GetUserByTokenId failed after adding PAT") + user, err := am.Store.GetUserByTokenID(context.Background(), tokenID) + if err != nil { + t.Fatalf("Error when getting user by token ID: %s", err) } - assert.Equal(t, mockUserID, userID) + + assert.Equal(t, mockUserID, user.Id) } func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { @@ -189,9 +192,12 @@ func TestUser_DeletePAT(t *testing.T) { t.Fatalf("Error when adding PAT to user: %s", err) } - assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID1]) - assert.Empty(t, store.HashedPAT2TokenID[mockToken1]) - assert.Empty(t, store.TokenID2UserID[mockTokenID1]) + account, err = store.GetAccount(context.Background(), mockAccountID) + if err != nil { + t.Fatalf("Error when getting account: %s", err) + } + + assert.Nil(t, account.Users[mockUserID].PATs[mockTokenID1]) } func TestUser_GetPAT(t *testing.T) { @@ -350,13 +356,16 @@ func TestUser_CreateServiceUser(t *testing.T) { t.Fatalf("Error when creating service user: %s", err) } - assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users)) - assert.NotNil(t, store.Accounts[mockAccountID].Users[user.ID]) - assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser) - assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role) - assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups) - assert.Equal(t, map[string]*PersonalAccessToken{}, store.Accounts[mockAccountID].Users[user.ID].PATs) + account, err = store.GetAccount(context.Background(), mockAccountID) + assert.NoError(t, err) + + assert.Equal(t, 2, len(account.Users)) + assert.NotNil(t, account.Users[user.ID]) + assert.True(t, account.Users[user.ID].IsServiceUser) + assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) + assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) + assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs) assert.Zero(t, user.Email) assert.True(t, user.IsServiceUser) @@ -394,12 +403,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { t.Fatalf("Error when creating user: %s", err) } + account, err = store.GetAccount(context.Background(), mockAccountID) + assert.NoError(t, err) + assert.True(t, user.IsServiceUser) - assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users)) - assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser) - assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role) - assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups) + assert.Equal(t, 2, len(account.Users)) + assert.True(t, account.Users[user.ID].IsServiceUser) + assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) + assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) assert.Equal(t, mockServiceUserName, user.Name) assert.Equal(t, mockRole, user.Role) @@ -550,12 +562,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID) tt.assertErrFunc(t, err, tt.assertErrMessage) + account, err2 := store.GetAccount(context.Background(), mockAccountID) + assert.NoError(t, err2) + if err != nil { - assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users)) - assert.NotNil(t, store.Accounts[mockAccountID].Users[mockServiceUserID]) + assert.Equal(t, 2, len(account.Users)) + assert.NotNil(t, account.Users[mockServiceUserID]) } else { - assert.Equal(t, 1, len(store.Accounts[mockAccountID].Users)) - assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID]) + assert.Equal(t, 1, len(account.Users)) + assert.Nil(t, account.Users[mockServiceUserID]) } }) } @@ -798,10 +813,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") - assert.NoError(t, err) - - acc, err := am.Store.GetAccount(context.Background(), accID) + acc, err := am.Store.GetAccount(context.Background(), account.Id) assert.NoError(t, err) for _, id := range tc.expectedDeleted { diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go index 13799713a..4dc98a0e0 100644 --- a/relay/metrics/realy.go +++ b/relay/metrics/realy.go @@ -16,8 +16,10 @@ const ( type Metrics struct { metric.Meter - TransferBytesSent metric.Int64Counter - TransferBytesRecv metric.Int64Counter + TransferBytesSent metric.Int64Counter + TransferBytesRecv metric.Int64Counter + AuthenticationTime metric.Float64Histogram + PeerStoreTime metric.Float64Histogram peers metric.Int64UpDownCounter peerActivityChan chan string @@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { return nil, err } + authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + + peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + m := &Metrics{ - Meter: meter, - TransferBytesSent: bytesSent, - TransferBytesRecv: bytesRecv, - peers: peers, + Meter: meter, + TransferBytesSent: bytesSent, + TransferBytesRecv: bytesRecv, + AuthenticationTime: authTime, + PeerStoreTime: peerStoreTime, + peers: peers, ctx: ctx, peerActivityChan: make(chan string, 10), @@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) { m.peerLastActive[id] = time.Time{} } +// RecordAuthenticationTime measures the time taken for peer authentication +func (m *Metrics) RecordAuthenticationTime(duration time.Duration) { + m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + +// RecordPeerStoreTime measures the time to store the peer in map +func (m *Metrics) RecordPeerStoreTime(duration time.Duration) { + m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + // PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections func (m *Metrics) PeerDisconnected(id string) { m.peers.Add(m.ctx, -1) @@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() { } } } + +func getStandardBucketBoundaries() []float64 { + return []float64{ + 0.1, + 0.5, + 1, + 5, + 10, + 50, + 100, + 500, + 1000, + 5000, + 10000, + } +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go new file mode 100644 index 000000000..0257300f8 --- /dev/null +++ b/relay/server/handshake.go @@ -0,0 +1,153 @@ +package server + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/auth" + "github.com/netbirdio/netbird/relay/messages" + //nolint:staticcheck + "github.com/netbirdio/netbird/relay/messages/address" + //nolint:staticcheck + authmsg "github.com/netbirdio/netbird/relay/messages/auth" +) + +// preparedMsg contains the marshalled success response messages +type preparedMsg struct { + responseHelloMsg []byte + responseAuthMsg []byte +} + +func newPreparedMsg(instanceURL string) (*preparedMsg, error) { + rhm, err := marshalResponseHelloMsg(instanceURL) + if err != nil { + return nil, err + } + + ram, err := messages.MarshalAuthResponse(instanceURL) + if err != nil { + return nil, fmt.Errorf("failed to marshal auth response msg: %w", err) + } + + return &preparedMsg{ + responseHelloMsg: rhm, + responseAuthMsg: ram, + }, nil +} + +func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { + addr := &address.Address{URL: instanceURL} + addrData, err := addr.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal response address: %w", err) + } + + //nolint:staticcheck + responseMsg, err := messages.MarshalHelloResponse(addrData) + if err != nil { + return nil, fmt.Errorf("failed to marshal hello response: %w", err) + } + return responseMsg, nil +} + +type handshake struct { + conn net.Conn + validator auth.Validator + preparedMsg *preparedMsg + + handshakeMethodAuth bool + peerID string +} + +func (h *handshake) handshakeReceive() ([]byte, error) { + buf := make([]byte, messages.MaxHandshakeSize) + n, err := h.conn.Read(buf) + if err != nil { + return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) + } + + _, err = messages.ValidateVersion(buf[:n]) + if err != nil { + return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err) + } + + msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) + if err != nil { + return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) + } + + var ( + bytePeerID []byte + peerID string + ) + switch msgType { + //nolint:staticcheck + case messages.MsgTypeHello: + bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n]) + case messages.MsgTypeAuth: + h.handshakeMethodAuth = true + bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n]) + default: + return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) + } + if err != nil { + return nil, err + } + h.peerID = peerID + return bytePeerID, nil +} + +func (h *handshake) handshakeResponse() error { + var responseMsg []byte + if h.handshakeMethodAuth { + responseMsg = h.preparedMsg.responseAuthMsg + } else { + responseMsg = h.preparedMsg.responseHelloMsg + } + + if _, err := h.conn.Write(responseMsg); err != nil { + return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) + } + + return nil +} + +func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { + //nolint:staticcheck + rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) + + authMsg, err := authmsg.UnmarshalMsg(authData) + if err != nil { + return nil, "", fmt.Errorf("unmarshal auth message: %w", err) + } + + //nolint:staticcheck + if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} + +func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { + rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + + if err := h.validator.Validate(authPayload); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} diff --git a/relay/server/relay.go b/relay/server/relay.go index 76c01a697..6cd8506ae 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -7,16 +7,13 @@ import ( "net/url" "strings" "sync" + "time" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/relay/auth" - "github.com/netbirdio/netbird/relay/messages" //nolint:staticcheck - "github.com/netbirdio/netbird/relay/messages/address" - //nolint:staticcheck - authmsg "github.com/netbirdio/netbird/relay/messages/auth" "github.com/netbirdio/netbird/relay/metrics" ) @@ -28,6 +25,7 @@ type Relay struct { store *Store instanceURL string + preparedMsg *preparedMsg closed bool closeMu sync.RWMutex @@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida return nil, fmt.Errorf("get instance URL: %v", err) } + r.preparedMsg, err = newPreparedMsg(r.instanceURL) + if err != nil { + metricsCancel() + return nil, fmt.Errorf("prepare message: %v", err) + } + return r, nil } @@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { // Accept start to handle a new peer connection func (r *Relay) Accept(conn net.Conn) { + acceptTime := time.Now() r.closeMu.RLock() defer r.closeMu.RUnlock() if r.closed { return } - peerID, err := r.handshake(conn) + h := handshake{ + conn: conn, + validator: r.validator, + preparedMsg: r.preparedMsg, + } + peerID, err := h.handshakeReceive() if err != nil { log.Errorf("failed to handshake: %s", err) - cErr := conn.Close() - if cErr != nil { + if cErr := conn.Close(); cErr != nil { log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) } return @@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) { peer := NewPeer(r.metrics, peerID, conn, r.store) peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) + storeTime := time.Now() r.store.AddPeer(peer) + r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.PeerConnected(peer.String()) go func() { peer.Work() @@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) { peer.log.Debugf("relay connection closed") r.metrics.PeerDisconnected(peer.String()) }() + + if err := h.handshakeResponse(); err != nil { + log.Errorf("failed to send handshake response, close peer: %s", err) + peer.Close() + } + r.metrics.RecordAuthenticationTime(time.Since(acceptTime)) } // Shutdown closes the relay server @@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) { func (r *Relay) InstanceURL() string { return r.instanceURL } - -func (r *Relay) handshake(conn net.Conn) ([]byte, error) { - buf := make([]byte, messages.MaxHandshakeSize) - n, err := conn.Read(buf) - if err != nil { - return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err) - } - - _, err = messages.ValidateVersion(buf[:n]) - if err != nil { - return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err) - } - - msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) - if err != nil { - return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) - } - - var ( - responseMsg []byte - peerID []byte - ) - switch msgType { - //nolint:staticcheck - case messages.MsgTypeHello: - peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - case messages.MsgTypeAuth: - peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - default: - return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr()) - } - if err != nil { - return nil, err - } - - _, err = conn.Write(responseMsg) - if err != nil { - return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) - } - - return peerID, nil -} - -func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) { - //nolint:staticcheck - rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := messages.HashIDToString(rawPeerID) - log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr) - - authMsg, err := authmsg.UnmarshalMsg(authData) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal auth message: %w", err) - } - - //nolint:staticcheck - if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err) - } - - addr := &address.Address{URL: r.instanceURL} - addrData, err := addr.Marshal() - if err != nil { - return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err) - } - - //nolint:staticcheck - responseMsg, err := messages.MarshalHelloResponse(addrData) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err) - } - return rawPeerID, responseMsg, nil -} - -func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) { - rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := messages.HashIDToString(rawPeerID) - - if err := r.validator.Validate(authPayload); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err) - } - - responseMsg, err := messages.MarshalAuthResponse(r.instanceURL) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err) - } - - return rawPeerID, responseMsg, nil -} diff --git a/release_files/install.sh b/release_files/install.sh index d6aabebd8..b7a6c08f9 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -21,6 +21,8 @@ SUDO="" if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then SUDO="sudo" +elif command -v doas > /dev/null && [ "$(id -u)" -ne 0 ]; then + SUDO="doas" fi if [ -z ${NETBIRD_RELEASE+x} ]; then @@ -31,14 +33,16 @@ get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then local TAG="latest" + local URL="https://pkgs.netbird.io/releases/latest" else local TAG="tags/${RELEASE}" + local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ + curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' else - curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ + curl -s "${URL}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' fi } @@ -68,7 +72,7 @@ download_release_binary() { if [ -n "$GITHUB_TOKEN" ]; then cd /tmp && curl -H "Authorization: token ${GITHUB_TOKEN}" -LO "$DOWNLOAD_URL" else - cd /tmp && curl -LO "$DOWNLOAD_URL" + cd /tmp && curl -LO "$DOWNLOAD_URL" || curl -LO --dns-servers 8.8.8.8 "$DOWNLOAD_URL" fi @@ -316,7 +320,7 @@ install_netbird() { } version_greater_equal() { - printf '%s\n%s\n' "$2" "$1" | sort -V -C + printf '%s\n%s\n' "$2" "$1" | sort -V -c } is_bin_package_manager() { diff --git a/route/route.go b/route/route.go index eb6c36bd8..e23801e6e 100644 --- a/route/route.go +++ b/route/route.go @@ -100,6 +100,7 @@ type Route struct { Metric int Enabled bool Groups []string `gorm:"serializer:json"` + AccessControlGroups []string `gorm:"serializer:json"` } // EventMeta returns activity event meta related to the route @@ -123,6 +124,7 @@ func (r *Route) Copy() *Route { Masquerade: r.Masquerade, Enabled: r.Enabled, Groups: slices.Clone(r.Groups), + AccessControlGroups: slices.Clone(r.AccessControlGroups), } return route } @@ -147,7 +149,8 @@ func (r *Route) IsEqual(other *Route) bool { other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && - slices.Equal(r.PeerGroups, other.PeerGroups) + slices.Equal(r.PeerGroups, other.PeerGroups)&& + slices.Equal(r.AccessControlGroups, other.AccessControlGroups) } // IsDynamic returns if the route is dynamic, i.e. has domains diff --git a/signal/client/client_test.go b/signal/client/client_test.go index 2525493b4..f7d4ebc50 100644 --- a/signal/client/client_test.go +++ b/signal/client/client_test.go @@ -199,7 +199,7 @@ func startSignal() (*grpc.Server, net.Listener) { panic(err) } s := grpc.NewServer() - srv, err := server.NewServer(otel.Meter("")) + srv, err := server.NewServer(context.Background(), otel.Meter("")) if err != nil { panic(err) } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 0bdc62ead..1bb2f1d0c 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -102,7 +102,7 @@ var ( } }() - srv, err := server.NewServer(metricsServer.Meter) + srv, err := server.NewServer(cmd.Context(), metricsServer.Meter) if err != nil { return fmt.Errorf("creating signal server: %v", err) } diff --git a/signal/server/signal.go b/signal/server/signal.go index b268aa3fc..305fd052b 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -6,6 +6,7 @@ import ( "io" "time" + "github.com/netbirdio/signal-dispatcher/dispatcher" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -13,8 +14,6 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "github.com/netbirdio/signal-dispatcher/dispatcher" - "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" "github.com/netbirdio/netbird/signal/proto" @@ -47,13 +46,13 @@ type Server struct { } // NewServer creates a new Signal server -func NewServer(meter metric.Meter) (*Server, error) { +func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { appMetrics, err := metrics.NewAppMetrics(meter) if err != nil { return nil, fmt.Errorf("creating app metrics: %v", err) } - dispatcher, err := dispatcher.NewDispatcher() + dispatcher, err := dispatcher.NewDispatcher(ctx, meter) if err != nil { return nil, fmt.Errorf("creating dispatcher: %v", err) } @@ -71,11 +70,6 @@ func NewServer(meter metric.Meter) (*Server, error) { func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.Debugf("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) - if msg.RemoteKey == "dummy" { - // Test message send during netbird status - return &proto.EncryptedMessage{}, nil - } - if _, found := s.registry.Get(msg.RemoteKey); found { s.forwardMessageToPeer(ctx, msg) return &proto.EncryptedMessage{}, nil diff --git a/util/file.go b/util/file.go index 8355488c9..ecaecd222 100644 --- a/util/file.go +++ b/util/file.go @@ -1,11 +1,15 @@ package util import ( + "bytes" "context" "encoding/json" + "fmt" "io" "os" "path/filepath" + "strings" + "text/template" log "github.com/sirupsen/logrus" ) @@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) { return res, nil } +// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution +func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) { + envVars := getEnvMap() + + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + bs, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + t, err := template.New("").Parse(string(bs)) + if err != nil { + return nil, fmt.Errorf("error parsing template: %v", err) + } + + var output bytes.Buffer + // Execute the template, substituting environment variables + err = t.Execute(&output, envVars) + if err != nil { + return nil, fmt.Errorf("error executing template: %v", err) + } + + err = json.Unmarshal(output.Bytes(), &res) + if err != nil { + return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err) + } + + return res, nil +} + +// getEnvMap Convert the output of os.Environ() to a map +func getEnvMap() map[string]string { + envMap := make(map[string]string) + + for _, env := range os.Environ() { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + return envMap +} + // CopyFileContents copies contents of the given src file to the dst file func CopyFileContents(src, dst string) (err error) { in, err := os.Open(src) diff --git a/util/file_suite_test.go b/util/file_suite_test.go new file mode 100644 index 000000000..3de7db49b --- /dev/null +++ b/util/file_suite_test.go @@ -0,0 +1,126 @@ +package util_test + +import ( + "crypto/md5" + "encoding/hex" + "io" + "os" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/netbirdio/netbird/util" +) + +var _ = Describe("Client", func() { + + var ( + tmpDir string + ) + + type TestConfig struct { + SomeMap map[string]string + SomeArray []string + SomeField int + } + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + err := os.RemoveAll(tmpDir) + Expect(err).NotTo(HaveOccurred()) + }) + + Describe("Config", func() { + Context("in JSON format", func() { + It("should be written and read successfully", func() { + + m := make(map[string]string) + m["key1"] = "value1" + m["key2"] = "value2" + + arr := []string{"value1", "value2"} + + written := &TestConfig{ + SomeMap: m, + SomeArray: arr, + SomeField: 99, + } + + err := util.WriteJson(tmpDir+"/testconfig.json", written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) + Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) + Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) + Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) + + }) + }) + }) + + Describe("Copying file contents", func() { + Context("from one file to another", func() { + It("should be successful", func() { + + src := tmpDir + "/copytest_src" + dst := tmpDir + "/copytest_dst" + + err := util.WriteJson(src, []string{"1", "2", "3"}) + Expect(err).NotTo(HaveOccurred()) + + err = util.CopyFileContents(src, dst) + Expect(err).NotTo(HaveOccurred()) + + hashSrc := md5.New() + hashDst := md5.New() + + srcFile, err := os.Open(src) + Expect(err).NotTo(HaveOccurred()) + + dstFile, err := os.Open(dst) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashSrc, srcFile) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashDst, dstFile) + Expect(err).NotTo(HaveOccurred()) + + err = srcFile.Close() + Expect(err).NotTo(HaveOccurred()) + + err = dstFile.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) + }) + }) + }) + + Describe("Handle config file without full path", func() { + Context("config file handling", func() { + It("should be successful", func() { + written := &TestConfig{ + SomeField: 123, + } + cfgFile := "test_cfg.json" + defer os.Remove(cfgFile) + + err := util.WriteJson(cfgFile, written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(cfgFile, &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + }) + }) + }) +}) diff --git a/util/file_test.go b/util/file_test.go index 3de7db49b..1330e738e 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,126 +1,198 @@ -package util_test +package util import ( - "crypto/md5" - "encoding/hex" - "io" "os" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - - "github.com/netbirdio/netbird/util" + "reflect" + "strings" + "testing" ) -var _ = Describe("Client", func() { - - var ( - tmpDir string - ) - - type TestConfig struct { - SomeMap map[string]string - SomeArray []string - SomeField int +func TestReadJsonWithEnvSub(t *testing.T) { + type Config struct { + CertFile string `json:"CertFile"` + Credentials string `json:"Credentials"` + NestedOption struct { + URL string `json:"URL"` + } `json:"NestedOption"` } - BeforeEach(func() { - var err error - tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - }) + type testCase struct { + name string + envVars map[string]string + jsonTemplate string + expectedResult Config + expectError bool + errorContains string + } - AfterEach(func() { - err := os.RemoveAll(tmpDir) - Expect(err).NotTo(HaveOccurred()) - }) + tests := []testCase{ + { + name: "All environment variables set", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://env.testing.com", + }, + }, + expectError: false, + }, + { + name: "Missing environment variable", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + // "URL" is intentionally missing + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "", + }, + }, + expectError: false, + }, + { + name: "Invalid JSON template", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, // Note the missing closing brace in "{{ .CREDENTIALS }" + expectedResult: Config{}, + expectError: true, + errorContains: "unexpected \"}\" in operand", + }, + { + name: "No substitutions", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "/etc/certs/cert.crt", + "Credentials": "admnlknflkdasdf", + "NestedOption" : { + "URL": "https://testing.com" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/cert.crt", + Credentials: "admnlknflkdasdf", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://testing.com", + }, + }, + expectError: false, + }, + { + name: "Should fail when Invalid characters in variables", + envVars: map[string]string{ + "CERT_FILE": `"/etc/certs/"cert".crt"`, + "CREDENTIALS": `env_credentia{ls}`, + "URL": `https://env.testing.com?param={{value}}`, + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: `"/etc/certs/"cert".crt"`, + Credentials: `env_credentia{ls}`, + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: `https://env.testing.com?param={{value}}`, + }, + }, + expectError: true, + }, + } - Describe("Config", func() { - Context("in JSON format", func() { - It("should be written and read successfully", func() { + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + for key, value := range tc.envVars { + t.Setenv(key, value) + } - m := make(map[string]string) - m["key1"] = "value1" - m["key2"] = "value2" + tempFile, err := os.CreateTemp("", "config*.json") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } - arr := []string{"value1", "value2"} - - written := &TestConfig{ - SomeMap: m, - SomeArray: arr, - SomeField: 99, + defer func() { + err = os.Remove(tempFile.Name()) + if err != nil { + t.Logf("Failed to remove temp file: %v", err) } + }() - err := util.WriteJson(tmpDir+"/testconfig.json", written) - Expect(err).NotTo(HaveOccurred()) + _, err = tempFile.WriteString(tc.jsonTemplate) + if err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + err = tempFile.Close() + if err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } - read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) - Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) - Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) - Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) + var result Config - }) - }) - }) + _, err = ReadJsonWithEnvSub(tempFile.Name(), &result) - Describe("Copying file contents", func() { - Context("from one file to another", func() { - It("should be successful", func() { - - src := tmpDir + "/copytest_src" - dst := tmpDir + "/copytest_dst" - - err := util.WriteJson(src, []string{"1", "2", "3"}) - Expect(err).NotTo(HaveOccurred()) - - err = util.CopyFileContents(src, dst) - Expect(err).NotTo(HaveOccurred()) - - hashSrc := md5.New() - hashDst := md5.New() - - srcFile, err := os.Open(src) - Expect(err).NotTo(HaveOccurred()) - - dstFile, err := os.Open(dst) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashSrc, srcFile) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashDst, dstFile) - Expect(err).NotTo(HaveOccurred()) - - err = srcFile.Close() - Expect(err).NotTo(HaveOccurred()) - - err = dstFile.Close() - Expect(err).NotTo(HaveOccurred()) - - Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) - }) - }) - }) - - Describe("Handle config file without full path", func() { - Context("config file handling", func() { - It("should be successful", func() { - written := &TestConfig{ - SomeField: 123, + if tc.expectError { + if err == nil { + t.Fatalf("Expected error but got none") } - cfgFile := "test_cfg.json" - defer os.Remove(cfgFile) - - err := util.WriteJson(cfgFile, written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(cfgFile, &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - }) + if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err) + } + } else { + if err != nil { + t.Fatalf("ReadJsonWithEnvSub failed: %v", err) + } + if !reflect.DeepEqual(result, tc.expectedResult) { + t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult) + } + } }) - }) -}) + } +} diff --git a/util/net/net.go b/util/net/net.go index 8d1fcebd0..035d7552b 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -4,14 +4,15 @@ import ( "net" "os" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/google/uuid" ) const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 + NetbirdFwmark = 0x1BD00 + PreroutingFwmark = 0x1BD01 envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" )