diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 000000000..70aebf578 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,15 @@ +FROM golang:1.20-bullseye + +RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ + && apt-get -y install --no-install-recommends\ + gettext-base=0.21-4 \ + iptables=1.8.7-1 \ + libgl1-mesa-dev=20.3.5-1 \ + xorg-dev=1:7.7+22 \ + libayatana-appindicator3-dev=0.5.5-2+deb11u2 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* \ + && go install -v golang.org/x/tools/gopls@latest + + +WORKDIR /app diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 000000000..7dce7f058 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,20 @@ +{ + "name": "NetBird", + "build": { + "context": "..", + "dockerfile": "Dockerfile" + }, + "features": { + "ghcr.io/devcontainers/features/docker-in-docker:2": {}, + "ghcr.io/devcontainers/features/go:1": { + "version": "1.20" + } + }, + "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", + "capAdd": [ + "NET_ADMIN", + "SYS_ADMIN", + "SYS_RESOURCE" + ], + "privileged": true +} \ No newline at end of file diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..d207b1802 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.go text eol=lf diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 97fdeabe8..5998fab01 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -12,6 +12,9 @@ concurrency: jobs: test: + strategy: + matrix: + store: ['jsonfile', 'sqlite'] runs-on: macos-latest steps: - name: Install Go @@ -33,4 +36,4 @@ jobs: run: go mod tidy - name: Test - run: go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./... + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./... diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 13061f6eb..8015fb36a 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -15,6 +15,7 @@ jobs: strategy: matrix: arch: ['386','amd64'] + store: ['jsonfile', 'sqlite'] runs-on: ubuntu-latest steps: - name: Install Go @@ -41,17 +42,16 @@ jobs: run: go mod tidy - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./... test_client_on_docker: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - name: Install Go uses: actions/setup-go@v4 with: go-version: "1.20.x" - - name: Cache Go modules uses: actions/cache@v3 with: @@ -64,7 +64,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib - name: Install modules run: go mod tidy @@ -82,7 +82,7 @@ jobs: run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... - name: Generate Engine Test bin - run: CGO_ENABLED=0 go test -c -o engine-testing.bin ./client/internal + run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal - name: Generate Peer Test bin run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/... @@ -95,15 +95,17 @@ jobs: - name: Run Iface tests in docker run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1 - - name: Run RouteManager tests in docker run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 - name: Run nftables Manager tests in docker run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1 - - name: Run Engine tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 + - name: Run Engine tests in docker with file store + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 + + - name: Run Engine tests in docker with sqlite store + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 - name: Run Peer tests in docker run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1 \ No newline at end of file diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 6dd91666c..34f0ec680 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -39,7 +39,9 @@ jobs: - run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\' - - run: choco install -y sysinternals + - run: choco install -y sysinternals --ignore-checksums + - run: choco install -y mingw + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 2a5c51c8a..4e584ecc2 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -1,12 +1,23 @@ name: golangci-lint on: [pull_request] + +permissions: + contents: read + pull-requests: read + concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} cancel-in-progress: true + jobs: golangci: + strategy: + fail-fast: false + matrix: + os: [macos-latest, windows-latest, ubuntu-latest] name: lint - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} + timeout-minutes: 15 steps: - name: Checkout code uses: actions/checkout@v3 @@ -14,7 +25,12 @@ jobs: uses: actions/setup-go@v4 with: go-version: "1.20.x" + cache: false - name: Install dependencies + if: matrix.os == 'ubuntu-latest' run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev - name: golangci-lint - uses: golangci/golangci-lint-action@v3 \ No newline at end of file + uses: golangci/golangci-lint-action@v3 + with: + version: latest + args: --timeout=12m \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3feefdd49..5833638c5 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,6 +17,7 @@ on: - 'release_files/**' - '**/Dockerfile' - '**/Dockerfile.*' + - 'client/ui/**' env: SIGN_PIPE_VER: "v0.0.9" diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index c2c4f7598..6482b716f 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -8,6 +8,8 @@ on: paths: - 'infrastructure_files/**' - '.github/workflows/test-infrastructure-files.yml' + - 'management/cmd/**' + - 'signal/cmd/**' concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} @@ -56,6 +58,8 @@ jobs: CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified" + CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite" + CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false - name: check values working-directory: infrastructure_files @@ -81,6 +85,8 @@ jobs: CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_SIGNAL_PORT: 12345 + CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite" + CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false run: | grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID @@ -97,7 +103,9 @@ jobs: grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE - grep -A 8 DeviceAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE" + grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE + grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE" + grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH" grep UseIDToken management.json | grep false grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY @@ -105,12 +113,13 @@ jobs: grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials - grep -A 2 PKCEAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE - grep -A 3 PKCEAuthorizationFlow management.json | grep -A 2 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID - grep -A 4 PKCEAuthorizationFlow management.json | grep -A 3 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET - grep -A 5 PKCEAuthorizationFlow management.json | grep -A 4 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT - grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT - grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" + grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE + grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID + grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET + grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT + grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT + grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" + grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000" - name: Install modules run: go mod tidy diff --git a/.gitignore b/.gitignore index dc62780ad..7f7f53ce8 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ client/.distfiles/ infrastructure_files/setup.env infrastructure_files/setup-*.env .vscode -.DS_Store \ No newline at end of file +.DS_Store +*.db \ No newline at end of file diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index c6f7a7c34..66a22ee34 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -54,7 +54,7 @@ nfpms: contents: - src: client/ui/netbird.desktop dst: /usr/share/applications/netbird.desktop - - src: client/ui/disconnected.png + - src: client/ui/netbird-systemtray-default.png dst: /usr/share/pixmaps/netbird.png dependencies: - netbird @@ -71,7 +71,7 @@ nfpms: contents: - src: client/ui/netbird.desktop dst: /usr/share/applications/netbird.desktop - - src: client/ui/disconnected.png + - src: client/ui/netbird-systemtray-default.png dst: /usr/share/pixmaps/netbird.png dependencies: - netbird @@ -91,4 +91,4 @@ uploads: mode: archive target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} username: dev@wiretrustee.com - method: PUT \ No newline at end of file + method: PUT diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 298251108..80be72fa9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,7 +23,6 @@ If you haven't already, join our slack workspace [here](https://join.slack.com/t - [Test suite](#test-suite) - [Checklist before submitting a PR](#checklist-before-submitting-a-pr) - [Other project repositories](#other-project-repositories) - - [Checklist before submitting a new node](#checklist-before-submitting-a-new-node) - [Contributor License Agreement](#contributor-license-agreement) ## Code of conduct @@ -70,7 +69,7 @@ dependencies are installed. Here is a short guide on how that can be done. ### Requirements -#### Go 1.19 +#### Go 1.21 Follow the installation guide from https://go.dev/ @@ -139,15 +138,14 @@ checked out and set up: ### Build and start #### Client -> Windows clients have a Wireguard driver requirement. We provide a bash script that can be executed in WLS 2 with docker support [wireguard_nt.sh](/client/wireguard_nt.sh). - To start NetBird, execute: ``` cd client -# bash wireguard_nt.sh # if windows -go build . +CGO_ENABLED=0 go build . ``` +> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to the same path as your binary file or to `C:\Windows\System32\wintun.dll`. + To start NetBird the client in the foreground: ``` @@ -215,4 +213,4 @@ NetBird project is composed of 3 main repositories: That we do not have any potential problems later it is sadly necessary to sign a [Contributor License Agreement](CONTRIBUTOR_LICENSE_AGREEMENT.md). That can be done literally with the push of a button. -A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in. \ No newline at end of file +A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in. diff --git a/client/android/client.go b/client/android/client.go index bb15268eb..f2dd9e3f7 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,8 +8,8 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" @@ -31,9 +31,9 @@ type IFaceDiscover interface { stdnet.ExternalIFaceDiscover } -// RouteListener export internal RouteListener for mobile -type RouteListener interface { - routemanager.RouteListener +// NetworkChangeListener export internal NetworkChangeListener for mobile +type NetworkChangeListener interface { + listener.NetworkChangeListener } // DnsReadyListener export internal dns ReadyListener for mobile @@ -47,26 +47,26 @@ func init() { // Client struct manage the life circle of background service type Client struct { - cfgFile string - tunAdapter iface.TunAdapter - iFaceDiscover IFaceDiscover - recorder *peer.Status - ctxCancel context.CancelFunc - ctxCancelLock *sync.Mutex - deviceName string - routeListener routemanager.RouteListener + cfgFile string + tunAdapter iface.TunAdapter + iFaceDiscover IFaceDiscover + recorder *peer.Status + ctxCancel context.CancelFunc + ctxCancelLock *sync.Mutex + deviceName string + networkChangeListener listener.NetworkChangeListener } // NewClient instantiate a new Client -func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, routeListener RouteListener) *Client { +func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { return &Client{ - cfgFile: cfgFile, - deviceName: deviceName, - tunAdapter: tunAdapter, - iFaceDiscover: iFaceDiscover, - recorder: peer.NewRecorder(""), - ctxCancelLock: &sync.Mutex{}, - routeListener: routeListener, + cfgFile: cfgFile, + deviceName: deviceName, + tunAdapter: tunAdapter, + iFaceDiscover: iFaceDiscover, + recorder: peer.NewRecorder(""), + ctxCancelLock: &sync.Mutex{}, + networkChangeListener: networkChangeListener, } } @@ -96,7 +96,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener) + return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). @@ -120,7 +120,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener) + return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) } // Stop the internal client and free the resources diff --git a/client/android/preferences_test.go b/client/android/preferences_test.go index 73c8692d7..985175913 100644 --- a/client/android/preferences_test.go +++ b/client/android/preferences_test.go @@ -57,11 +57,11 @@ func TestPreferences_ReadUncommitedValues(t *testing.T) { p.SetManagementURL(exampleString) resp, err = p.GetManagementURL() if err != nil { - t.Fatalf("failed to read managmenet url: %s", err) + t.Fatalf("failed to read management url: %s", err) } if resp != exampleString { - t.Errorf("unexpected managemenet url: %s", resp) + t.Errorf("unexpected management url: %s", resp) } p.SetPreSharedKey(exampleString) @@ -102,11 +102,11 @@ func TestPreferences_Commit(t *testing.T) { resp, err = p.GetManagementURL() if err != nil { - t.Fatalf("failed to read managmenet url: %s", err) + t.Fatalf("failed to read management url: %s", err) } if resp != exampleURL { - t.Errorf("unexpected managemenet url: %s", resp) + t.Errorf("unexpected management url: %s", resp) } resp, err = p.GetPreSharedKey() diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 6d47021dd..47ae9ddb4 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -65,7 +65,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste t.Fatal(err) } s := grpc.NewServer() - store, err := mgmt.NewFileStore(config.Datadir, nil) + store, err := mgmt.NewStoreFromJson(config.Datadir, nil) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index 8d682c46b..80ed04b57 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -123,7 +123,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { defer func() { err := conn.Close() if err != nil { - log.Warnf("failed closing dameon gRPC client connection %v", err) + log.Warnf("failed closing daemon gRPC client connection %v", err) return } }() @@ -200,11 +200,11 @@ func validateNATExternalIPs(list []string) error { subElements := strings.Split(element, "/") if len(subElements) > 2 { - return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"String\" or \"String/String\"", element, externalIPMapFlag) + return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"String\" or \"String/String\"", element, externalIPMapFlag) } if len(subElements) == 1 && !isValidIP(subElements[0]) { - return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag) + return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag) } last := 0 @@ -259,7 +259,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) { var parsed []byte if modified { if !isValidAddrPort(customDNSAddress) { - return nil, fmt.Errorf("%s is invalid, it should be formated as IP:Port string or as an empty string like \"\"", customDNSAddress) + return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress) } if customDNSAddress == "" && logFile != "console" { parsed = []byte("empty") diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 048c0fd50..4ce904df6 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -192,7 +192,7 @@ func (m *Manager) AddFiltering( } if ipsetName != "" { // ipset name is defined and it means that this rule was created - // for it, need to assosiate it with ruleset + // for it, need to associate it with ruleset m.rulesets[ipsetName] = ruleset{ rule: rule, ips: map[string]string{rule.ip: ruleID}, @@ -236,7 +236,7 @@ func (m *Manager) DeleteRule(rule fw.Rule) error { } // we delete last IP from the set, that means we need to delete - // set itself and assosiated firewall rule too + // set itself and associated firewall rule too delete(m.rulesets, r.ipsetName) if err := ipset.Destroy(r.ipsetName); err != nil { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 2273f4edc..6c46048b4 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -754,7 +754,7 @@ func (m *Manager) AllowNetbird() error { } if chain == nil { - log.Debugf("chain INPUT not found. Skiping add allow netbird rule") + log.Debugf("chain INPUT not found. Skipping add allow netbird rule") return nil } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 164d5d0dc..0a5c499b2 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -148,7 +148,7 @@ func TestNftablesManager(t *testing.T) { // test expectations: // 1) "accept extra routed traffic rule" for the interface // 2) "drop all rule" for the interface - require.Len(t, rules, 2, "expected 2 rules after deleteion") + require.Len(t, rules, 2, "expected 2 rules after deletion") err = manager.Reset() require.NoError(t, err, "failed to reset") diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 05a6d22ae..140bfc87a 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -1,21 +1,19 @@ package uspfilter import ( - "errors" "fmt" "os/exec" - "strings" "syscall" + + log "github.com/sirupsen/logrus" ) type action string const ( - addRule action = "add" - deleteRule action = "delete" - - firewallRuleName = "Netbird" - noRulesMatchCriteria = "No rules match the specified criteria" + addRule action = "add" + deleteRule action = "delete" + firewallRuleName = "Netbird" ) // Reset firewall to the default state @@ -26,6 +24,14 @@ func (m *Manager) Reset() error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if !isWindowsFirewallReachable() { + return nil + } + + if !isFirewallRuleActive(firewallRuleName) { + return nil + } + if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { return fmt.Errorf("couldn't remove windows firewall: %w", err) } @@ -35,6 +41,13 @@ func (m *Manager) Reset() error { // AllowNetbird allows netbird interface traffic func (m *Manager) AllowNetbird() error { + if !isWindowsFirewallReachable() { + return nil + } + + if isFirewallRuleActive(firewallRuleName) { + return nil + } return manageFirewallRule(firewallRuleName, addRule, "dir=in", @@ -45,47 +58,37 @@ func (m *Manager) AllowNetbird() error { ) } -func manageFirewallRule(ruleName string, action action, args ...string) error { - active, err := isFirewallRuleActive(ruleName) - if err != nil { - return err +func manageFirewallRule(ruleName string, action action, extraArgs ...string) error { + + args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName} + if action == addRule { + args = append(args, extraArgs...) } - if (action == addRule && !active) || (action == deleteRule && active) { - baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName} - args := append(baseArgs, args...) - - cmd := exec.Command("netsh", args...) - cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} - return cmd.Run() - } - - return nil + cmd := exec.Command("netsh", args...) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + return cmd.Run() } -func isFirewallRuleActive(ruleName string) (bool, error) { +func isWindowsFirewallReachable() bool { + args := []string{"advfirewall", "show", "allprofiles", "state"} + cmd := exec.Command("netsh", args...) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + + _, err := cmd.Output() + if err != nil { + log.Infof("Windows firewall is not reachable, skipping default rule management. Using only user space rules. Error: %s", err) + return false + } + + return true +} + +func isFirewallRuleActive(ruleName string) bool { args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName} cmd := exec.Command("netsh", args...) cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} - output, err := cmd.Output() - if err != nil { - var exitError *exec.ExitError - if errors.As(err, &exitError) { - // if the firewall rule is not active, we expect last exit code to be 1 - exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus() - if exitStatus == 1 { - if strings.Contains(string(output), noRulesMatchCriteria) { - return false, nil - } - } - } - return false, err - } - - if strings.Contains(string(output), noRulesMatchCriteria) { - return false, nil - } - - return true, nil + _, err := cmd.Output() + return err == nil } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 50170b46c..6fd11e652 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -32,7 +32,7 @@ type Manager struct { wgNetwork *net.IPNet decoders sync.Pool wgIface IFaceMapper - resetHook func() error + resetHook func() error mutex sync.RWMutex } @@ -188,7 +188,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool { return m.dropFilter(packetData, m.incomingRules, true) } -// dropFilter imlements same logic for booth direction of the traffic +// dropFilter implements same logic for booth direction of the traffic func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool { m.mutex.RLock() defer m.mutex.RUnlock() diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 9a9e624d6..feaaa7b8b 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -53,7 +53,7 @@ func newDefaultManager(fm firewall.Manager) *DefaultManager { // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // -// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains. +// If allowByDefault is true it appends allow ALL traffic rules to input and output chains. func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { d.mutex.Lock() defer d.mutex.Unlock() @@ -366,7 +366,7 @@ func (d *DefaultManager) squashAcceptRules( protocols[r.Protocol] = map[string]int{} } - // special case, when we recieve this all network IP address + // special case, when we receive this all network IP address // it means that rules for that protocol was already optimized on the // management side if r.PeerIP == "0.0.0.0" { @@ -393,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules( } // order of squashing by protocol is important - // only for ther first element ALL, it must be done first + // only for their first element ALL, it must be done first protocolOrders := []mgmProto.FirewallRuleProtocol{ mgmProto.FirewallRule_ALL, mgmProto.FirewallRule_ICMP, diff --git a/client/internal/acl/manager_create.go b/client/internal/acl/manager_create.go index 2fdca02ae..66185749b 100644 --- a/client/internal/acl/manager_create.go +++ b/client/internal/acl/manager_create.go @@ -20,7 +20,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) { return nil, err } if err := fm.AllowNetbird(); err != nil { - log.Errorf("failed to allow netbird interface traffic: %v", err) + log.Warnf("failed to allow netbird interface traffic: %v", err) } return newDefaultManager(fm), nil } diff --git a/client/internal/connect.go b/client/internal/connect.go index aa3ac629b..747146d10 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -13,8 +13,8 @@ import ( gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -22,6 +22,7 @@ import ( mgm "github.com/netbirdio/netbird/management/client" mgmProto "github.com/netbirdio/netbird/management/proto" signal "github.com/netbirdio/netbird/signal/client" + "github.com/netbirdio/netbird/version" ) // RunClient with main logic. @@ -30,14 +31,14 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) } // RunClientMobile with main logic on mobile system -func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error { +func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error { // in case of non Android os these variables will be nil mobileDependency := MobileDependency{ - TunAdapter: tunAdapter, - IFaceDiscover: iFaceDiscover, - RouteListener: routeListener, - HostDNSAddresses: dnsAddresses, - DnsReadyListener: dnsReadyListener, + TunAdapter: tunAdapter, + IFaceDiscover: iFaceDiscover, + NetworkChangeListener: networkChangeListener, + HostDNSAddresses: dnsAddresses, + DnsReadyListener: dnsReadyListener, } return runClient(ctx, config, statusRecorder, mobileDependency) } @@ -53,6 +54,8 @@ func RunClientiOS(ctx context.Context, config *Config, statusRecorder *peer.Stat } func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error { + log.Infof("starting NetBird client version %s", version.NetbirdVersion()) + backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -106,7 +109,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, cancel() }() - log.Debugf("conecting to the Management service %s", config.ManagementURL.Host) + log.Debugf("connecting to the Management service %s", config.ManagementURL.Host) mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) if err != nil { return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err)) diff --git a/client/internal/dns/file_linux.go b/client/internal/dns/file_linux.go index fb59ba63b..67128f79a 100644 --- a/client/internal/dns/file_linux.go +++ b/client/internal/dns/file_linux.go @@ -3,29 +3,25 @@ package dns import ( + "bufio" "bytes" "fmt" "os" + "strings" log "github.com/sirupsen/logrus" ) const ( - fileGeneratedResolvConfContentHeader = "# Generated by NetBird" - fileGeneratedResolvConfSearchBeginContent = "search " - fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader + - "\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" + - fileGeneratedResolvConfSearchBeginContent + "%s\n\n" + - "%s\n" -) + fileGeneratedResolvConfContentHeader = "# Generated by NetBird" + fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + ` +# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n" -const ( fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird" - fileMaxLineCharsLimit = 256 - fileMaxNumberOfSearchDomains = 6 -) -var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent) + fileMaxLineCharsLimit = 256 + fileMaxNumberOfSearchDomains = 6 +) type fileConfigurator struct { originalPerms os.FileMode @@ -55,58 +51,39 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { } return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") } - managerType, err := getOSDNSManagerType() - if err != nil { - return err - } - switch managerType { - case fileManager, netbirdManager: - if !backupFileExist { - err = f.backup() - if err != nil { - return fmt.Errorf("unable to backup the resolv.conf file") - } - } - default: - // todo improve this and maybe restart DNS manager from scratch - return fmt.Errorf("something happened and file manager is not your prefered host dns configurator, restart the agent") - } - var searchDomains string - appendedDomains := 0 - for _, dConf := range config.domains { - if dConf.matchOnly || dConf.disabled { - continue - } - if appendedDomains >= fileMaxNumberOfSearchDomains { - // lets log all skipped domains - log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain) - continue - } - if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit { - // lets log all skipped domains - log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain) - continue - } - - searchDomains += " " + dConf.domain - appendedDomains++ - } - - originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation) - if err != nil { - log.Errorf("Could not read existing resolv.conf") - } - content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent)) - err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms) - if err != nil { - err = f.restore() + if !backupFileExist { + err = f.backup() if err != nil { + return fmt.Errorf("unable to backup the resolv.conf file") + } + } + + searchDomainList := searchDomains(config) + + originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation) + if err != nil { + log.Error(err) + } + + searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains) + + buf := prepareResolvConfContent( + searchDomainList, + append([]string{config.serverIP}, nameServers...), + others) + + log.Debugf("creating managed file %s", defaultResolvConfPath) + err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms) + if err != nil { + restoreErr := f.restore() + if restoreErr != nil { log.Errorf("attempt to restore default file failed with error: %s", err) } - return err + return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err) } - log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, appendedDomains, searchDomains) + + log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList) return nil } @@ -138,15 +115,138 @@ func (f *fileConfigurator) restore() error { return os.RemoveAll(fileDefaultResolvConfBackupLocation) } -func writeDNSConfig(content, fileName string, permissions os.FileMode) error { - log.Debugf("creating managed file %s", fileName) +func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer { var buf bytes.Buffer - buf.WriteString(content) - err := os.WriteFile(fileName, buf.Bytes(), permissions) - if err != nil { - return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err) + buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine) + + for _, cfgLine := range others { + buf.WriteString(cfgLine) + buf.WriteString("\n") } - return nil + + if len(searchDomains) > 0 { + buf.WriteString("search ") + buf.WriteString(strings.Join(searchDomains, " ")) + buf.WriteString("\n") + } + + for _, ns := range nameServers { + buf.WriteString("nameserver ") + buf.WriteString(ns) + buf.WriteString("\n") + } + return buf +} + +func searchDomains(config hostDNSConfig) []string { + listOfDomains := make([]string, 0) + for _, dConf := range config.domains { + if dConf.matchOnly || dConf.disabled { + continue + } + + listOfDomains = append(listOfDomains, dConf.domain) + } + return listOfDomains +} + +func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) { + file, err := os.Open(resolvconfFile) + if err != nil { + err = fmt.Errorf(`could not read existing resolv.conf`) + return + } + defer file.Close() + + reader := bufio.NewReader(file) + + for { + lineBytes, isPrefix, readErr := reader.ReadLine() + if readErr != nil { + break + } + + if isPrefix { + err = fmt.Errorf(`resolv.conf line too long`) + return + } + + line := strings.TrimSpace(string(lineBytes)) + + if strings.HasPrefix(line, "#") { + continue + } + + if strings.HasPrefix(line, "domain") { + continue + } + + if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") { + line = strings.ReplaceAll(line, "rotate", "") + splitLines := strings.Fields(line) + if len(splitLines) == 1 { + continue + } + line = strings.Join(splitLines, " ") + } + + if strings.HasPrefix(line, "search") { + splitLines := strings.Fields(line) + if len(splitLines) < 2 { + continue + } + + searchDomains = splitLines[1:] + continue + } + + if strings.HasPrefix(line, "nameserver") { + splitLines := strings.Fields(line) + if len(splitLines) != 2 { + continue + } + nameServers = append(nameServers, splitLines[1]) + continue + } + + others = append(others, line) + } + return +} + +// merge search domains lists and cut off the list if it is too long +func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string { + lineSize := len("search") + searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains)) + + lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains) + _ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains) + + return searchDomainsList +} + +// validateAndFillSearchDomains checks if the search domains list is not too long and if the line is not too long +// extend s slice with vs elements +// return with the number of characters in the searchDomains line +func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int { + for _, sd := range vs { + tmpCharsNumber := initialLineChars + 1 + len(sd) + if tmpCharsNumber > fileMaxLineCharsLimit { + // lets log all skipped domains + log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd) + continue + } + + initialLineChars = tmpCharsNumber + + if len(*s) >= fileMaxNumberOfSearchDomains { + // lets log all skipped domains + log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd) + continue + } + *s = append(*s, sd) + } + return initialLineChars } func copyFile(src, dest string) error { diff --git a/client/internal/dns/file_linux_test.go b/client/internal/dns/file_linux_test.go new file mode 100644 index 000000000..369a47ef4 --- /dev/null +++ b/client/internal/dns/file_linux_test.go @@ -0,0 +1,62 @@ +package dns + +import ( + "fmt" + "testing" +) + +func Test_mergeSearchDomains(t *testing.T) { + searchDomains := []string{"a", "b"} + originDomains := []string{"a", "b"} + mergedDomains := mergeSearchDomains(searchDomains, originDomains) + if len(mergedDomains) != 4 { + t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4) + } +} + +func Test_mergeSearchTooMuchDomains(t *testing.T) { + searchDomains := []string{"a", "b", "c", "d", "e", "f", "g"} + originDomains := []string{"h", "i"} + mergedDomains := mergeSearchDomains(searchDomains, originDomains) + if len(mergedDomains) != 6 { + t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6) + } +} + +func Test_mergeSearchTooMuchDomainsInOrigin(t *testing.T) { + searchDomains := []string{"a", "b"} + originDomains := []string{"c", "d", "e", "f", "g"} + mergedDomains := mergeSearchDomains(searchDomains, originDomains) + if len(mergedDomains) != 6 { + t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6) + } +} + +func Test_mergeSearchTooLongDomain(t *testing.T) { + searchDomains := []string{getLongLine()} + originDomains := []string{"b"} + mergedDomains := mergeSearchDomains(searchDomains, originDomains) + if len(mergedDomains) != 1 { + t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1) + } + + searchDomains = []string{"b"} + originDomains = []string{getLongLine()} + + mergedDomains = mergeSearchDomains(searchDomains, originDomains) + if len(mergedDomains) != 1 { + t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1) + } +} + +func getLongLine() string { + x := "search " + for { + for i := 0; i <= 9; i++ { + if len(x) > fileMaxLineCharsLimit { + return x + } + x = fmt.Sprintf("%s%d", x, i) + } + } +} diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index 16c6c032d..34a0769a2 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -78,7 +78,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD for _, domain := range nsConfig.Domains { config.domains = append(config.domains, domainConfig{ domain: strings.TrimSuffix(domain, "."), - matchOnly: true, + matchOnly: !nsConfig.SearchDomainsEnabled, }) } } diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index ffb35ef6b..1e88a6c7b 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -22,13 +22,11 @@ const ( interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces" interfaceConfigNameServerKey = "NameServer" interfaceConfigSearchListKey = "SearchList" - tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters" ) type registryConfigurator struct { - guid string - routingAll bool - existingSearchDomains []string + guid string + routingAll bool } func newHostManager(wgInterface WGIface) (hostManager, error) { @@ -148,30 +146,11 @@ func (r *registryConfigurator) restoreHostDNS() error { log.Error(err) } - return r.updateSearchDomains([]string{}) + return r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey) } func (r *registryConfigurator) updateSearchDomains(domains []string) error { - value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey) - if err != nil { - return fmt.Errorf("unable to get current search domains failed with error: %s", err) - } - - valueList := strings.Split(value, ",") - setExisting := false - if len(r.existingSearchDomains) == 0 { - r.existingSearchDomains = valueList - setExisting = true - } - - if len(domains) == 0 && setExisting { - log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains) - return nil - } - - newList := append(r.existingSearchDomains, domains...) - - err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ",")) + err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")) if err != nil { return fmt.Errorf("adding search domain failed with error: %s", err) } @@ -235,33 +214,3 @@ func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error { } return nil } - -func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) { - regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE) - if err != nil { - return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err) - } - defer regKey.Close() - - val, _, err := regKey.GetStringValue(key) - if err != nil { - return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err) - } - - return val, nil -} - -func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error { - regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE) - if err != nil { - return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err) - } - defer regKey.Close() - - err = regKey.SetStringValue(key, value) - if err != nil { - return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err) - } - - return nil -} diff --git a/client/internal/dns/mockServer.go b/client/internal/dns/mockServer.go index 8970eec6e..3534fc0c3 100644 --- a/client/internal/dns/mockServer.go +++ b/client/internal/dns/mockServer.go @@ -2,6 +2,7 @@ package dns import ( "fmt" + nbdns "github.com/netbirdio/netbird/dns" ) @@ -43,3 +44,7 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { } return fmt.Errorf("method UpdateDNSServer is not implemented") } + +func (m *MockServer) SearchDomains() []string { + return make([]string, 0) +} diff --git a/client/internal/dns/notifier.go b/client/internal/dns/notifier.go new file mode 100644 index 000000000..85c270e58 --- /dev/null +++ b/client/internal/dns/notifier.go @@ -0,0 +1,57 @@ +package dns + +import ( + "reflect" + "sort" + "sync" + + "github.com/netbirdio/netbird/client/internal/listener" +) + +type notifier struct { + listener listener.NetworkChangeListener + listenerMux sync.Mutex + searchDomains []string +} + +func newNotifier(initialSearchDomains []string) *notifier { + sort.Strings(initialSearchDomains) + return ¬ifier{ + searchDomains: initialSearchDomains, + } +} + +func (n *notifier) setListener(listener listener.NetworkChangeListener) { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + n.listener = listener +} + +func (n *notifier) onNewSearchDomains(searchDomains []string) { + sort.Strings(searchDomains) + + if len(n.searchDomains) != len(searchDomains) { + n.searchDomains = searchDomains + n.notify() + return + } + + if reflect.DeepEqual(n.searchDomains, searchDomains) { + return + } + + n.searchDomains = searchDomains + n.notify() +} + +func (n *notifier) notify() { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + if n.listener == nil { + return + } + + go func(l listener.NetworkChangeListener) { + l.OnNetworkChanged() + }(n.listener) +} diff --git a/client/internal/dns/resolvconf_linux.go b/client/internal/dns/resolvconf_linux.go index 17e4c0196..9bde4b15a 100644 --- a/client/internal/dns/resolvconf_linux.go +++ b/client/internal/dns/resolvconf_linux.go @@ -3,10 +3,9 @@ package dns import ( + "bytes" "fmt" - "os" "os/exec" - "strings" log "github.com/sirupsen/logrus" ) @@ -15,11 +14,24 @@ const resolvconfCommand = "resolvconf" type resolvconf struct { ifaceName string + + originalSearchDomains []string + originalNameServers []string + othersConfigs []string } +// supported "openresolv" only func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) { + originalSearchDomains, nameServers, others, err := originalDNSConfigs("/etc/resolv.conf") + if err != nil { + log.Error(err) + } + return &resolvconf{ - ifaceName: wgInterface.Name(), + ifaceName: wgInterface.Name(), + originalSearchDomains: originalSearchDomains, + originalNameServers: nameServers, + othersConfigs: others, }, nil } @@ -37,41 +49,20 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured") } - var searchDomains string - appendedDomains := 0 - for _, dConf := range config.domains { - if dConf.matchOnly || dConf.disabled { - continue - } + searchDomainList := searchDomains(config) + searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains) - if appendedDomains >= fileMaxNumberOfSearchDomains { - // lets log all skipped domains - log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain) - continue - } + buf := prepareResolvConfContent( + searchDomainList, + append([]string{config.serverIP}, r.originalNameServers...), + r.othersConfigs) - if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit { - // lets log all skipped domains - log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain) - continue - } - - searchDomains += " " + dConf.domain - appendedDomains++ - } - - originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation) - if err != nil { - log.Errorf("Could not read existing resolv.conf") - } - content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent)) - - err = r.applyConfig(content) + err = r.applyConfig(buf) if err != nil { return err } - log.Infof("added %d search domains. Search list: %s", appendedDomains, searchDomains) + log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList) return nil } @@ -84,12 +75,12 @@ func (r *resolvconf) restoreHostDNS() error { return nil } -func (r *resolvconf) applyConfig(content string) error { +func (r *resolvconf) applyConfig(content bytes.Buffer) error { cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName) - cmd.Stdin = strings.NewReader(content) + cmd.Stdin = &content _, err := cmd.Output() if err != nil { - return fmt.Errorf("got an error while appying resolvconf configuration for %s interface, error: %s", r.ifaceName, err) + return fmt.Errorf("got an error while applying resolvconf configuration for %s interface, error: %s", r.ifaceName, err) } return nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index fb763bec2..122aae7b5 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -4,13 +4,13 @@ import ( "context" "fmt" "net/netip" - "runtime" "sync" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/listener" nbdns "github.com/netbirdio/netbird/dns" ) @@ -31,6 +31,7 @@ type Server interface { DnsIP() string UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(strings []string) + SearchDomains() []string } type registeredHandlerMap map[string]handlerWithStop @@ -56,6 +57,9 @@ type DefaultServer struct { interfaceName string wgAddr string + + // make sense on mobile only + searchDomainNotifier *notifier } type handlerWithStop interface { @@ -90,12 +94,15 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st } // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems -func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer { +func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), "", "") ds.permanent = true ds.hostsDnsList = hostsDnsList ds.addHostRootZone() + ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort()) + ds.searchDomainNotifier = newNotifier(ds.SearchDomains()) + ds.searchDomainNotifier.setListener(listener) setServerDns(ds) return ds } @@ -227,6 +234,21 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro } } +func (s *DefaultServer) SearchDomains() []string { + var searchDomains []string + + for _, dConf := range s.currentConfig.domains { + if dConf.disabled { + continue + } + if dConf.matchOnly { + continue + } + searchDomains = append(searchDomains, dConf.domain) + } + return searchDomains +} + func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be disabled, we stop the listener or fake resolver // and proceed with a regular update to clean up the handlers and records @@ -261,6 +283,10 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { log.Error(err) } + if s.searchDomainNotifier != nil { + s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) + } + return nil } @@ -303,7 +329,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam handler := newUpstreamResolver(s.ctx, s.interfaceName, s.wgAddr) for _, ns := range nsGroup.NameServers { if ns.NSType != nbdns.UDPNameServerType { - log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", + log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) continue } @@ -321,7 +347,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam // reapply DNS settings, but it not touch the original configuration and serial number // because it is temporal deactivation until next try // - // after some period defined by upstream it trys to reactivate self by calling this hook + // after some period defined by upstream it tries to reactivate self by calling this hook // everything we need here is just to re-apply current configuration because it already // contains this upstream settings (temporal deactivation not removed it) handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) diff --git a/client/internal/dns/server_export_test.go b/client/internal/dns/server_export_test.go index 784dcb3ad..1fa343b52 100644 --- a/client/internal/dns/server_export_test.go +++ b/client/internal/dns/server_export_test.go @@ -19,6 +19,6 @@ func TestGetServerDns(t *testing.T) { } if srvB != srv { - t.Errorf("missmatch dns instances") + t.Errorf("mismatch dns instances") } } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index c23b31249..d9fec43c5 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -593,7 +593,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { defer wgIFace.Close() var dnsList []string - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList) + dnsConfig := nbdns.Config{} + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -616,8 +617,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) { t.Fatal("failed to initialize wg interface") } defer wgIFace.Close() - - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}) + dnsConfig := nbdns.Config{} + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -708,8 +709,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) { t.Fatal("failed to initialize wg interface") } defer wgIFace.Close() - - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}) + dnsConfig := nbdns.Config{} + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index d5a74c3da..b93dd5bb4 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -118,7 +118,7 @@ func (u *upstreamResolver) getClientPrivate() *dns.Client { } func (u *upstreamResolver) stop() { - log.Debugf("stoping serving DNS for upstreams %s", u.upstreamServers) + log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() } diff --git a/client/internal/engine.go b/client/internal/engine.go index 58f459de5..8f1f1315f 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -195,12 +195,13 @@ func (e *Engine) Start() error { var routes []*route.Route if runtime.GOOS == "android" { - routes, err = e.readInitialSettings() + var dnsConfig *nbdns.Config + routes, dnsConfig, err = e.readInitialSettings() if err != nil { return err } if e.dnsServer == nil { - e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses) + e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener) go e.mobileDep.DnsReadyListener.OnReady() } } else { @@ -214,17 +215,16 @@ func (e *Engine) Start() error { } } - log.Debugf("Initial routes contain %d routes", len(routes)) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes) - e.mobileDep.RouteListener.SetInterfaceIP(wgAddr) - - e.routeManager.SetRouteChangeListener(e.mobileDep.RouteListener) + e.mobileDep.NetworkChangeListener.SetInterfaceIP(wgAddr) + e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) switch runtime.GOOS { case "android": err = e.wgInterface.CreateOnAndroid(iface.MobileIFaceArguments{ - Routes: e.routeManager.InitialRouteRange(), - Dns: e.dnsServer.DnsIP(), + Routes: e.routeManager.InitialRouteRange(), + Dns: e.dnsServer.DnsIP(), + SearchDomains: e.dnsServer.SearchDomains(), }) case "ios": err = e.wgInterface.CreateOniOS(e.mobileDep.FileDescriptor) @@ -724,8 +724,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { for _, nsGroup := range protoDNSConfig.GetNameServerGroups() { dnsNSGroup := &nbdns.NameServerGroup{ - Primary: nsGroup.GetPrimary(), - Domains: nsGroup.GetDomains(), + Primary: nsGroup.GetPrimary(), + Domains: nsGroup.GetDomains(), + SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(), } for _, ns := range nsGroup.GetNameServers() { dnsNS := nbdns.NameServer{ @@ -1060,13 +1061,14 @@ func (e *Engine) close() { } } -func (e *Engine) readInitialSettings() ([]*route.Route, error) { +func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { netMap, err := e.mgmClient.GetNetworkMap() if err != nil { - return nil, err + return nil, nil, err } routes := toRoutes(netMap.GetRoutes()) - return routes, nil + dnsCfg := toDNSConfig(netMap.GetDNSConfig()) + return routes, &dnsCfg, nil } func findIPFromInterfaceName(ifaceName string) (net.IP, error) { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index ea4a23a8d..42012bd0a 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1039,10 +1039,11 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, err := server.NewFileStore(config.Datadir, nil) + store, err := server.NewStoreFromJson(config.Datadir, nil) if err != nil { - log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) + return nil, "", err } + peersUpdateManager := server.NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} if err != nil { diff --git a/client/internal/listener/network_change.go b/client/internal/listener/network_change.go new file mode 100644 index 000000000..ff9cb11f5 --- /dev/null +++ b/client/internal/listener/network_change.go @@ -0,0 +1,7 @@ +package listener + +// NetworkChangeListener is a callback interface for mobile system +type NetworkChangeListener interface { + // OnNetworkChanged invoke when network settings has been changed + OnNetworkChanged() +} diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index b027937ad..0f762a570 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -2,19 +2,19 @@ package internal import ( "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/iface" ) // MobileDependency collect all dependencies for mobile platform type MobileDependency struct { - TunAdapter iface.TunAdapter - IFaceDiscover stdnet.ExternalIFaceDiscover - RouteListener routemanager.RouteListener - HostDNSAddresses []string - DnsReadyListener dns.ReadyListener - DnsManager dns.IosDnsManager - FileDescriptor int32 - InterfaceName string + TunAdapter iface.TunAdapter + IFaceDiscover stdnet.ExternalIFaceDiscover + NetworkChangeListener listener.NetworkChangeListener + HostDNSAddresses []string + DnsReadyListener dns.ReadyListener + DnsManager dns.IosDnsManager + FileDescriptor int32 + InterfaceName string } diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 62fe4dfc1..fda7b012f 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -119,7 +119,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers) } else if chosen != currID { - log.Infof("new chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore) + log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network) } return chosen diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index b31fe6327..1f812983c 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" @@ -16,7 +17,7 @@ import ( // Manager is a route manager interface type Manager interface { UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error - SetRouteChangeListener(listener RouteListener) + SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string Stop() } @@ -96,7 +97,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro } // SetRouteChangeListener set RouteListener for route change notifier -func (m *DefaultManager) SetRouteChangeListener(listener RouteListener) { +func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) { m.notifier.setListener(listener) } @@ -155,7 +156,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] // if prefix is too small, lets assume is a possible default route which is not yet supported // we skip this route management if newRoute.Network.Bits() < 7 { - log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route", + log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", version.NetbirdVersion(), newRoute.Network) continue } diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index f56dbfb17..8970841a2 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -32,7 +33,7 @@ func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) { } // SetRouteChangeListener mock implementation of SetRouteChangeListener from Manager interface -func (m *MockManager) SetRouteChangeListener(listener RouteListener) { +func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeListener) { } diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index 25dc6e7db..e62b1a404 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -135,7 +135,8 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { } for _, table := range tables { - if table.Name == "filter" { + if table.Name == "filter" && table.Family == nftables.TableFamilyIPv4 { + log.Debugf("nftables: found filter table for ipv4") n.filterTable = table continue } @@ -486,7 +487,7 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error { if len(n.rules) == 2 && n.defaultForwardRules[0] != nil { err := n.eraseDefaultForwardRule() if err != nil { - log.Errorf("failed to delte default fwd rule: %s", err) + log.Errorf("failed to delete default fwd rule: %s", err) } } diff --git a/client/internal/routemanager/notifier.go b/client/internal/routemanager/notifier.go index 752cdd7db..a10c76ac0 100644 --- a/client/internal/routemanager/notifier.go +++ b/client/internal/routemanager/notifier.go @@ -2,37 +2,34 @@ package routemanager import ( "sort" - "strings" "sync" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/route" ) // RouteListener is a callback interface for mobile system -type RouteListener interface { - // OnNewRouteSetting invoke when new route setting has been arrived - OnNewRouteSetting(string) - SetInterfaceIP(string) -} +// type RouteListener interface { +// // OnNewRouteSetting invoke when new route setting has been arrived +// OnNewRouteSetting(string) +// SetInterfaceIP(string) +// } type notifier struct { initialRouteRangers []string routeRangers []string - routeListener RouteListener - routeListenerMux sync.Mutex + listener listener.NetworkChangeListener + listenerMux sync.Mutex } func newNotifier() *notifier { return ¬ifier{} } -func (n *notifier) setListener(listener RouteListener) { - n.routeListenerMux.Lock() - defer n.routeListenerMux.Unlock() - n.routeListener = listener +func (n *notifier) setListener(listener listener.NetworkChangeListener) { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + n.listener = listener } func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) { @@ -63,16 +60,16 @@ func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) { } func (n *notifier) notify() { - n.routeListenerMux.Lock() - defer n.routeListenerMux.Unlock() - if n.routeListener == nil { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + if n.listener == nil { return } - go func(l RouteListener) { + go func(l listener.NetworkChangeListener) { log.Debugf("notifying route listener with route ranges: %s", strings.Join(n.routeRangers, ",")) - l.OnNewRouteSetting(strings.Join(n.routeRangers, ",")) - }(n.routeListener) + l.OnNetworkChanged(strings.Join(n.routeRangers, ",")) + }(n.listener) } func (n *notifier) hasDiff(a []string, b []string) bool { diff --git a/client/ui/build-ui-linux.sh b/client/ui/build-ui-linux.sh new file mode 100644 index 000000000..eab08214d --- /dev/null +++ b/client/ui/build-ui-linux.sh @@ -0,0 +1,5 @@ +#!/bin/bash +sudo apt update +sudo apt remove gir1.2-appindicator3-0.1 +sudo apt install -y libayatana-appindicator3-dev +go build \ No newline at end of file diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index e6b4394e8..e66d03d95 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -15,8 +15,10 @@ import ( "runtime" "strconv" "strings" + "sync" "syscall" "time" + "unicode" "fyne.io/fyne/v2" "fyne.io/fyne/v2/app" @@ -74,18 +76,30 @@ func main() { } } -//go:embed connected.ico +//go:embed netbird-systemtray-connected.ico var iconConnectedICO []byte -//go:embed connected.png +//go:embed netbird-systemtray-connected.png var iconConnectedPNG []byte -//go:embed disconnected.ico +//go:embed netbird-systemtray-default.ico var iconDisconnectedICO []byte -//go:embed disconnected.png +//go:embed netbird-systemtray-default.png var iconDisconnectedPNG []byte +//go:embed netbird-systemtray-update.ico +var iconUpdateICO []byte + +//go:embed netbird-systemtray-update.png +var iconUpdatePNG []byte + +//go:embed netbird-systemtray-update-cloud.ico +var iconUpdateCloudICO []byte + +//go:embed netbird-systemtray-update-cloud.png +var iconUpdateCloudPNG []byte + type serviceClient struct { ctx context.Context addr string @@ -93,14 +107,20 @@ type serviceClient struct { icConnected []byte icDisconnected []byte + icUpdate []byte + icUpdateCloud []byte // systray menu items - mStatus *systray.MenuItem - mUp *systray.MenuItem - mDown *systray.MenuItem - mAdminPanel *systray.MenuItem - mSettings *systray.MenuItem - mQuit *systray.MenuItem + mStatus *systray.MenuItem + mUp *systray.MenuItem + mDown *systray.MenuItem + mAdminPanel *systray.MenuItem + mSettings *systray.MenuItem + mAbout *systray.MenuItem + mVersionUI *systray.MenuItem + mVersionDaemon *systray.MenuItem + mUpdate *systray.MenuItem + mQuit *systray.MenuItem // application with main windows. app fyne.App @@ -118,6 +138,11 @@ type serviceClient struct { managementURL string preSharedKey string adminURL string + + update *version.Update + daemonVersion string + updateIndicationLock sync.Mutex + isUpdateIconActive bool } // newServiceClient instance constructor @@ -130,14 +155,20 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient app: a, showSettings: showSettings, + update: version.NewUpdate(), } if runtime.GOOS == "windows" { s.icConnected = iconConnectedICO s.icDisconnected = iconDisconnectedICO + s.icUpdate = iconUpdateICO + s.icUpdateCloud = iconUpdateCloudICO + } else { s.icConnected = iconConnectedPNG s.icDisconnected = iconDisconnectedPNG + s.icUpdate = iconUpdatePNG + s.icUpdateCloud = iconUpdateCloudPNG } if showSettings { @@ -202,9 +233,10 @@ func (s *serviceClient) getSettingsForm() *widget.Form { } _, err = client.Login(s.ctx, &proto.LoginRequest{ - ManagementUrl: s.iMngURL.Text, - AdminURL: s.iAdminURL.Text, - PreSharedKey: s.iPreSharedKey.Text, + ManagementUrl: s.iMngURL.Text, + AdminURL: s.iAdminURL.Text, + PreSharedKey: s.iPreSharedKey.Text, + IsLinuxDesktopClient: runtime.GOOS == "linux", }) if err != nil { log.Errorf("login to management URL: %v", err) @@ -233,7 +265,9 @@ func (s *serviceClient) login() error { return err } - loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{}) + loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ + IsLinuxDesktopClient: runtime.GOOS == "linux", + }) if err != nil { log.Errorf("login to management URL with: %v", err) return err @@ -325,19 +359,53 @@ func (s *serviceClient) updateStatus() error { return err } + s.updateIndicationLock.Lock() + defer s.updateIndicationLock.Unlock() + + var systrayIconState bool if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() { - systray.SetIcon(s.icConnected) + if !s.isUpdateIconActive { + systray.SetIcon(s.icConnected) + } systray.SetTooltip("NetBird (Connected)") s.mStatus.SetTitle("Connected") s.mUp.Disable() s.mDown.Enable() + systrayIconState = true } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { - systray.SetIcon(s.icDisconnected) + if !s.isUpdateIconActive { + systray.SetIcon(s.icDisconnected) + } systray.SetTooltip("NetBird (Disconnected)") s.mStatus.SetTitle("Disconnected") s.mDown.Disable() s.mUp.Enable() + systrayIconState = false } + + // the updater struct notify by the upgrades available only, but if meanwhile the daemon has successfully + // updated must reset the mUpdate visibility state + if s.daemonVersion != status.DaemonVersion { + s.mUpdate.Hide() + s.daemonVersion = status.DaemonVersion + + s.isUpdateIconActive = s.update.SetDaemonVersion(status.DaemonVersion) + if !s.isUpdateIconActive { + if systrayIconState { + systray.SetIcon(s.icConnected) + s.mAbout.SetIcon(s.icConnected) + } else { + systray.SetIcon(s.icDisconnected) + s.mAbout.SetIcon(s.icDisconnected) + } + } + + daemonVersionTitle := normalizedVersion(s.daemonVersion) + s.mVersionDaemon.SetTitle(fmt.Sprintf("Daemon: %s", daemonVersionTitle)) + s.mVersionDaemon.SetTooltip(fmt.Sprintf("Daemon version: %s", daemonVersionTitle)) + s.mVersionDaemon.Show() + } + return nil }, &backoff.ExponentialBackOff{ InitialInterval: time.Second, @@ -371,11 +439,24 @@ func (s *serviceClient) onTrayReady() { systray.AddSeparator() s.mSettings = systray.AddMenuItem("Settings", "Settings of the application") systray.AddSeparator() - v := systray.AddMenuItem("v"+version.NetbirdVersion(), "Client Version: "+version.NetbirdVersion()) - v.Disable() + + s.mAbout = systray.AddMenuItem("About", "About") + s.mAbout.SetIcon(s.icDisconnected) + versionString := normalizedVersion(version.NetbirdVersion()) + s.mVersionUI = s.mAbout.AddSubMenuItem(fmt.Sprintf("GUI: %s", versionString), fmt.Sprintf("GUI Version: %s", versionString)) + s.mVersionUI.Disable() + + s.mVersionDaemon = s.mAbout.AddSubMenuItem("", "") + s.mVersionDaemon.Disable() + s.mVersionDaemon.Hide() + + s.mUpdate = s.mAbout.AddSubMenuItem("Download latest version", "Download latest version") + s.mUpdate.Hide() + systray.AddSeparator() s.mQuit = systray.AddMenuItem("Quit", "Quit the client app") + s.update.SetOnUpdateListener(s.onUpdateAvailable) go func() { s.getSrvConfig() for { @@ -433,6 +514,11 @@ func (s *serviceClient) onTrayReady() { case <-s.mQuit.ClickedCh: systray.Quit() return + case <-s.mUpdate.ClickedCh: + err := openURL(version.DownloadUrl()) + if err != nil { + log.Errorf("%s", err) + } } if err != nil { log.Errorf("process connection: %v", err) @@ -441,6 +527,14 @@ func (s *serviceClient) onTrayReady() { }() } +func normalizedVersion(version string) string { + versionString := version + if unicode.IsDigit(rune(versionString[0])) { + versionString = fmt.Sprintf("v%s", versionString) + } + return versionString +} + func (s *serviceClient) onTrayExit() {} // getSrvClient connection to the service. @@ -501,6 +595,32 @@ func (s *serviceClient) getSrvConfig() { } } +func (s *serviceClient) onUpdateAvailable() { + s.updateIndicationLock.Lock() + defer s.updateIndicationLock.Unlock() + + s.mUpdate.Show() + s.mAbout.SetIcon(s.icUpdateCloud) + + s.isUpdateIconActive = true + systray.SetIcon(s.icUpdate) +} + +func openURL(url string) error { + var err error + switch runtime.GOOS { + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + case "linux": + err = exec.Command("xdg-open", url).Start() + default: + err = fmt.Errorf("unsupported platform") + } + return err +} + // checkPIDFile exists and return error, or write new. func checkPIDFile() error { pidFile := path.Join(os.TempDir(), "wiretrustee-ui.pid") diff --git a/client/ui/connected.ico b/client/ui/connected.ico deleted file mode 100644 index 3dd598fa7..000000000 Binary files a/client/ui/connected.ico and /dev/null differ diff --git a/client/ui/connected.png b/client/ui/connected.png deleted file mode 100644 index 64b25ade1..000000000 Binary files a/client/ui/connected.png and /dev/null differ diff --git a/client/ui/disconnected.ico b/client/ui/disconnected.ico deleted file mode 100644 index 2bab8a503..000000000 Binary files a/client/ui/disconnected.ico and /dev/null differ diff --git a/client/ui/disconnected.png b/client/ui/disconnected.png deleted file mode 100644 index c74a30b99..000000000 Binary files a/client/ui/disconnected.png and /dev/null differ diff --git a/client/ui/netbird-systemtray-connected.ico b/client/ui/netbird-systemtray-connected.ico new file mode 100644 index 000000000..621afce9f Binary files /dev/null and b/client/ui/netbird-systemtray-connected.ico differ diff --git a/client/ui/netbird-systemtray-connected.png b/client/ui/netbird-systemtray-connected.png new file mode 100644 index 000000000..c5878d018 Binary files /dev/null and b/client/ui/netbird-systemtray-connected.png differ diff --git a/client/ui/netbird-systemtray-default.ico b/client/ui/netbird-systemtray-default.ico new file mode 100644 index 000000000..5a0252675 Binary files /dev/null and b/client/ui/netbird-systemtray-default.ico differ diff --git a/client/ui/netbird-systemtray-default.png b/client/ui/netbird-systemtray-default.png new file mode 100644 index 000000000..12e7a2dc1 Binary files /dev/null and b/client/ui/netbird-systemtray-default.png differ diff --git a/client/ui/netbird-systemtray-update-cloud.ico b/client/ui/netbird-systemtray-update-cloud.ico new file mode 100644 index 000000000..b87c6f4b5 Binary files /dev/null and b/client/ui/netbird-systemtray-update-cloud.ico differ diff --git a/client/ui/netbird-systemtray-update-cloud.png b/client/ui/netbird-systemtray-update-cloud.png new file mode 100644 index 000000000..e9d0b8035 Binary files /dev/null and b/client/ui/netbird-systemtray-update-cloud.png differ diff --git a/client/ui/netbird-systemtray-update.ico b/client/ui/netbird-systemtray-update.ico new file mode 100644 index 000000000..1a1c4086d Binary files /dev/null and b/client/ui/netbird-systemtray-update.ico differ diff --git a/client/ui/netbird-systemtray-update.png b/client/ui/netbird-systemtray-update.png new file mode 100644 index 000000000..1f4651df9 Binary files /dev/null and b/client/ui/netbird-systemtray-update.png differ diff --git a/dns/nameserver.go b/dns/nameserver.go index 7751f8e1c..bb904b165 100644 --- a/dns/nameserver.go +++ b/dns/nameserver.go @@ -50,21 +50,25 @@ func ToNameServerType(typeString string) NameServerType { // NameServerGroup group of nameservers and with group ids type NameServerGroup struct { // ID identifier of group - ID string + ID string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `gorm:"index"` // Name group name Name string // Description group description Description string // NameServers list of nameservers - NameServers []NameServer + NameServers []NameServer `gorm:"serializer:json"` // Groups list of peer group IDs to distribute the nameservers information - Groups []string + Groups []string `gorm:"serializer:json"` // Primary indicates that the nameserver group is the primary resolver for any dns query Primary bool // Domains indicate the dns query domains to use with this nameserver group - Domains []string + Domains []string `gorm:"serializer:json"` // Enabled group status Enabled bool + // SearchDomainsEnabled indicates whether to add match domains to search domains list or not + SearchDomainsEnabled bool } // NameServer represents a DNS nameserver @@ -131,14 +135,15 @@ func ParseNameServerURL(nsURL string) (NameServer, error) { // Copy copies a nameserver group object func (g *NameServerGroup) Copy() *NameServerGroup { nsGroup := &NameServerGroup{ - ID: g.ID, - Name: g.Name, - Description: g.Description, - NameServers: make([]NameServer, len(g.NameServers)), - Groups: make([]string, len(g.Groups)), - Enabled: g.Enabled, - Primary: g.Primary, - Domains: make([]string, len(g.Domains)), + ID: g.ID, + Name: g.Name, + Description: g.Description, + NameServers: make([]NameServer, len(g.NameServers)), + Groups: make([]string, len(g.Groups)), + Enabled: g.Enabled, + Primary: g.Primary, + Domains: make([]string, len(g.Domains)), + SearchDomainsEnabled: g.SearchDomainsEnabled, } copy(nsGroup.NameServers, g.NameServers) @@ -154,6 +159,7 @@ func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool { other.Name == g.Name && other.Description == g.Description && other.Primary == g.Primary && + other.SearchDomainsEnabled == g.SearchDomainsEnabled && compareNameServerList(g.NameServers, other.NameServers) && compareGroupsList(g.Groups, other.Groups) && compareGroupsList(g.Domains, other.Domains) diff --git a/encryption/encryption.go b/encryption/encryption.go index 1c6ec7806..abdf4cb2f 100644 --- a/encryption/encryption.go +++ b/encryption/encryption.go @@ -30,7 +30,7 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes. return nil, err } if len(encryptedMsg) < nonceSize { - return nil, fmt.Errorf("invalid encrypted message lenght") + return nil, fmt.Errorf("invalid encrypted message length") } copy(nonce[:], encryptedMsg[:nonceSize]) opened, ok := box.Open(nil, encryptedMsg[nonceSize:], nonce, toByte32(peerPublicKey), toByte32(privateKey)) diff --git a/go.mod b/go.mod index 8be159997..c6c8221e1 100644 --- a/go.mod +++ b/go.mod @@ -17,12 +17,12 @@ require ( github.com/spf13/cobra v1.6.1 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.1.0 - golang.org/x/crypto v0.9.0 - golang.org/x/sys v0.8.0 + golang.org/x/crypto v0.14.0 + golang.org/x/sys v0.13.0 golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.55.0 + google.golang.org/grpc v1.56.3 google.golang.org/protobuf v1.30.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -46,11 +46,12 @@ require ( github.com/hashicorp/go-version v1.6.0 github.com/libp2p/go-netroute v0.2.0 github.com/magiconair/properties v1.8.5 - github.com/mattn/go-sqlite3 v1.14.16 + github.com/mattn/go-sqlite3 v1.14.17 github.com/mdlayher/socket v0.4.0 github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 + github.com/netbirdio/management-integrations/integrations v0.0.0-20231027143200-a966bce7db88 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pion/logging v0.2.2 @@ -68,12 +69,14 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 - golang.org/x/net v0.10.0 + golang.org/x/net v0.17.0 golang.org/x/oauth2 v0.8.0 golang.org/x/sync v0.2.0 - golang.org/x/term v0.8.0 + golang.org/x/term v0.13.0 google.golang.org/api v0.126.0 gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/sqlite v1.5.3 + gorm.io/gorm v1.25.4 ) require ( @@ -110,6 +113,8 @@ require ( github.com/googleapis/gax-go/v2 v2.10.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/native v1.0.0 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect @@ -135,9 +140,9 @@ require ( go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/otel/sdk v1.11.1 // indirect go.opentelemetry.io/otel/trace v1.11.1 // indirect - golang.org/x/image v0.5.0 // indirect + golang.org/x/image v0.10.0 // indirect golang.org/x/mod v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/text v0.13.0 // indirect golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect golang.org/x/tools v0.6.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect @@ -154,6 +159,6 @@ require ( replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 -replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c +replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20230524172305-5a498a82b33f diff --git a/go.sum b/go.sum index 25182ca85..84b8816e9 100644 --- a/go.sum +++ b/go.sum @@ -383,6 +383,10 @@ github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLf github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackmordaunt/icns v0.0.0-20181231085925-4f16af745526/go.mod h1:UQkeMHVoNcyXYq9otUupF7/h/2tmHlhrS2zw7ZVvUqc= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/josephspurrier/goversioninfo v0.0.0-20200309025242-14b0ab84c6ca/go.mod h1:eJTEwMjXb7kZ633hO3Ln9mBUCOjX2+FlTljvpl9SYdE= github.com/josharian/native v0.0.0-20200817173448-b6b71def0850/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= @@ -441,8 +445,8 @@ github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a/go.mod h1:M1qoD/MqPgTZIk0EWKB38wE28ACRfVcn+cU08jyArI0= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= -github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= @@ -491,10 +495,12 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20231027143200-a966bce7db88 h1:zhe8qseauBuYOS910jpl5sv8Tb+36zxQPXrwYXqll0g= +github.com/netbirdio/management-integrations/integrations v0.0.0-20231027143200-a966bce7db88/go.mod h1:KSqjzHcqlodTWiuap5lRXxt5KT3vtYRoksL0KIrTK40= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw= -github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM= +github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= +github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM= github.com/netbirdio/wireguard-go v0.0.0-20230524172305-5a498a82b33f h1:WQXGYCKPkNs1KusFTLieV73UVTNfZVyez4CFRvlOruM= github.com/netbirdio/wireguard-go v0.0.0-20230524172305-5a498a82b33f/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= @@ -724,8 +730,8 @@ golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -741,8 +747,8 @@ golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf/go.mod h1:yh0Ynu2b5ZUe3MQfp2 golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI= -golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4= +golang.org/x/image v0.10.0 h1:gXjUUtwtx5yOE0VKWq1CH4IJAClq4UGgUA3i+rpON9M= +golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -832,8 +838,9 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -957,15 +964,17 @@ golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -979,8 +988,10 @@ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1135,8 +1146,8 @@ google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ= google.golang.org/grpc v1.51.0-dev/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= -google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag= -google.golang.org/grpc v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8= +google.golang.org/grpc v1.56.3 h1:8I4C0Yq1EjstUzUJzpcRVbuYA2mODtEmpWiQoN/b2nc= +google.golang.org/grpc v1.56.3/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -1189,6 +1200,10 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g= +gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= +gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw= +gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= diff --git a/iface/bind/udp_mux_universal.go b/iface/bind/udp_mux_universal.go index 9792bba55..772f02b35 100644 --- a/iface/bind/udp_mux_universal.go +++ b/iface/bind/udp_mux_universal.go @@ -282,7 +282,7 @@ func (a *xorMapped) closeWaiters() { // just exit break default: - // notify tha twe have a new addr + // notify that twe have a new addr close(a.waitAddrReceived) } } diff --git a/iface/device_wrapper_test.go b/iface/device_wrapper_test.go index 9f5386587..2d3725ea4 100644 --- a/iface/device_wrapper_test.go +++ b/iface/device_wrapper_test.go @@ -59,7 +59,7 @@ func TestDeviceWrapperRead(t *testing.T) { n, err := wrapped.Read(bufs, sizes, offset) if err != nil { - t.Errorf("unexpeted error: %v", err) + t.Errorf("unexpected error: %v", err) return } if n != 1 { @@ -105,7 +105,7 @@ func TestDeviceWrapperRead(t *testing.T) { n, err := wrapped.Write(bufs, 0) if err != nil { - t.Errorf("unexpeted error: %v", err) + t.Errorf("unexpected error: %v", err) return } if n != 1 { @@ -154,7 +154,7 @@ func TestDeviceWrapperRead(t *testing.T) { n, err := wrapped.Write(bufs, 0) if err != nil { - t.Errorf("unexpeted error: %v", err) + t.Errorf("unexpected error: %v", err) return } if n != 0 { @@ -211,7 +211,7 @@ func TestDeviceWrapperRead(t *testing.T) { n, err := wrapped.Read(bufs, sizes, offset) if err != nil { - t.Errorf("unexpeted error: %v", err) + t.Errorf("unexpected error: %v", err) return } if n != 0 { diff --git a/iface/iface_test.go b/iface/iface_test.go index 3e0759d87..5ce276b75 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -13,7 +13,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// keep darwin compability +// keep darwin compatibility const ( WgIntNumber = 2000 ) diff --git a/iface/module_linux.go b/iface/module_linux.go index e943c0ba7..11c0482d5 100644 --- a/iface/module_linux.go +++ b/iface/module_linux.go @@ -110,7 +110,7 @@ func canCreateFakeWireGuardInterface() bool { // We willingly try to create a device with an invalid // MTU here as the validation of the MTU will be performed after // the validation of the link kind and hence allows us to check - // for the existance of the wireguard module without actually + // for the existence of the wireguard module without actually // creating a link. // // As a side-effect, this will also let the kernel lazy-load @@ -271,12 +271,12 @@ func moduleStatus(name string) (status, error) { func loadModuleWithDependencies(name, path string) error { deps, err := getModuleDependencies(name) if err != nil { - return fmt.Errorf("couldn't load list of module %s dependecies", name) + return fmt.Errorf("couldn't load list of module %s dependencies", name) } for _, dep := range deps { err = loadModule(dep.name, dep.path) if err != nil { - return fmt.Errorf("couldn't load dependecy module %s for %s", dep.name, name) + return fmt.Errorf("couldn't load dependency module %s for %s", dep.name, name) } } return loadModule(name, path) diff --git a/iface/tun.go b/iface/tun.go index 51a7783a1..ec8af4c32 100644 --- a/iface/tun.go +++ b/iface/tun.go @@ -1,8 +1,9 @@ package iface type MobileIFaceArguments struct { - Routes []string - Dns string + Routes []string + Dns string + SearchDomains []string } // NetInterface represents a generic network tunnel interface diff --git a/iface/tun_adapter.go b/iface/tun_adapter.go index 0ba0bde22..da0b1695b 100644 --- a/iface/tun_adapter.go +++ b/iface/tun_adapter.go @@ -2,6 +2,6 @@ package iface // TunAdapter is an interface for create tun device from externel service type TunAdapter interface { - ConfigureInterface(address string, mtu int, dns string, routes string) (int, error) + ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) UpdateAddr(address string) error } diff --git a/iface/tun_android.go b/iface/tun_android.go index 6f13296bd..e938dc57b 100644 --- a/iface/tun_android.go +++ b/iface/tun_android.go @@ -40,7 +40,8 @@ func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error { log.Info("create tun interface") var err error routesString := t.routesToString(mIFaceArgs.Routes) - t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString) + searchDomainsToString := t.searchDomainsToString(mIFaceArgs.SearchDomains) + t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) return err @@ -97,3 +98,7 @@ func (t *tunDevice) Close() (err error) { func (t *tunDevice) routesToString(routes []string) string { return strings.Join(routes, ";") } + +func (t *tunDevice) searchDomainsToString(searchDomains []string) string { + return strings.Join(searchDomains, ";") +} diff --git a/iface/tun_darwin.go b/iface/tun_darwin.go index 7735479b6..6e917e374 100644 --- a/iface/tun_darwin.go +++ b/iface/tun_darwin.go @@ -23,7 +23,7 @@ func (c *tunDevice) Create() error { func (c *tunDevice) assignAddr() error { cmd := exec.Command("ifconfig", c.name, "inet", c.address.IP.String(), c.address.IP.String()) if out, err := cmd.CombinedOutput(); err != nil { - log.Infof(`adding addreess command "%v" failed with output %s and error: `, cmd.String(), out) + log.Infof(`adding address command "%v" failed with output %s and error: `, cmd.String(), out) return err } diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env index f610a9691..e254aa6f3 100644 --- a/infrastructure_files/base.setup.env +++ b/infrastructure_files/base.setup.env @@ -6,14 +6,15 @@ NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073} # Management API endpoint address, used by the Dashboard NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT -# Management Certficate file path. These are generated by the Dashboard container +# Management Certificate file path. These are generated by the Dashboard container NETBIRD_LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/fullchain.pem" -# Management Certficate key file path. +# Management Certificate key file path. NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/privkey.pem" # By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted} +NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false} # Signal NETBIRD_SIGNAL_PROTOCOL="http" @@ -55,6 +56,9 @@ NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE NETBIRD_DASH_AUTH_USE_AUDIENCE=${NETBIRD_DASH_AUTH_USE_AUDIENCE:-true} NETBIRD_DASH_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE +# Store config +NETBIRD_STORE_CONFIG_ENGINE=${NETBIRD_STORE_CONFIG_ENGINE:-"jsonfile"} + # exports export NETBIRD_DOMAIN export NETBIRD_AUTH_CLIENT_ID @@ -86,6 +90,7 @@ export LETSENCRYPT_VOLUMESUFFIX export NETBIRD_DISABLE_ANONYMOUS_METRICS export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN export NETBIRD_MGMT_DNS_DOMAIN +export NETBIRD_MGMT_IDP_SIGNKEY_REFRESH export NETBIRD_SIGNAL_PROTOCOL export NETBIRD_SIGNAL_PORT export NETBIRD_AUTH_USER_ID_CLAIM @@ -97,4 +102,5 @@ export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT export NETBIRD_AUTH_PKCE_USE_ID_TOKEN export NETBIRD_AUTH_PKCE_AUDIENCE export NETBIRD_DASH_AUTH_USE_AUDIENCE -export NETBIRD_DASH_AUTH_AUDIENCE \ No newline at end of file +export NETBIRD_DASH_AUTH_AUDIENCE +export NETBIRD_STORE_CONFIG_ENGINE \ No newline at end of file diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index 3db799068..6d2902816 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -1,4 +1,5 @@ #!/bin/bash +set -e if ! which curl >/dev/null 2>&1; then echo "This script uses curl fetch OpenID configuration from IDP." @@ -124,7 +125,7 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "- $NETBIRD_SIGNAL_ENDPOINT/signalexchange.SignalExchange/ -grpc-> signal:80" echo "You most likely also have to change NETBIRD_MGMT_API_ENDPOINT in base.setup.env and port-mappings in docker-compose.yml.tmpl and rerun this script." echo " The target of the forwards depends on your setup. Beware of the gRPC protocol instead of http for management and signal!" - echo "You are also free to remove any occurences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" + echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" export NETBIRD_SIGNAL_PROTOCOL="https" @@ -154,6 +155,8 @@ if [ -n "$NETBIRD_MGMT_IDP" ]; then export NETBIRD_IDP_MGMT_CLIENT_ID export NETBIRD_IDP_MGMT_CLIENT_SECRET export NETBIRD_IDP_MGMT_EXTRA_CONFIG=$EXTRA_CONFIG +else + export NETBIRD_IDP_MGMT_EXTRA_CONFIG={} fi IFS=',' read -r -a REDIRECT_URL_PORTS <<< "$NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS" @@ -170,8 +173,29 @@ if [ "$NETBIRD_DASH_AUTH_USE_AUDIENCE" = "false" ]; then export NETBIRD_AUTH_PKCE_AUDIENCE= fi +# Read the encryption key +if test -f 'management.json'; then + encKey=$(jq -r ".DataStoreEncryptionKey" management.json) + if [[ "$encKey" != "null" ]]; then + export NETBIRD_DATASTORE_ENC_KEY=$encKey + + fi +fi + env | grep NETBIRD +bkp_postfix="$(date +%s)" +if test -f 'docker-compose.yml'; then + cp docker-compose.yml "docker-compose.yml.bkp.${bkp_postfix}" +fi + +if test -f 'management.json'; then + cp management.json "management.json.bkp.${bkp_postfix}" +fi + +if test -f 'turnserver.conf'; then + cp turnserver.conf "turnserver.conf.bpk.${bkp_postfix}" +fi envsubst docker-compose.yml -envsubst management.json -envsubst turnserver.conf \ No newline at end of file +envsubst management.json +envsubst turnserver.conf diff --git a/infrastructure_files/management.json.tmpl b/infrastructure_files/management.json.tmpl index e185faa6e..64c2d0816 100644 --- a/infrastructure_files/management.json.tmpl +++ b/infrastructure_files/management.json.tmpl @@ -27,6 +27,10 @@ "Password": null }, "Datadir": "", + "DataStoreEncryptionKey": "$NETBIRD_DATASTORE_ENC_KEY", + "StoreConfig": { + "Engine": "$NETBIRD_STORE_CONFIG_ENGINE" + }, "HttpConfig": { "Address": "0.0.0.0:$NETBIRD_MGMT_API_PORT", "AuthIssuer": "$NETBIRD_AUTH_AUTHORITY", @@ -35,6 +39,7 @@ "AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM", "CertFile":"$NETBIRD_MGMT_API_CERT_FILE", "CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE", + "IdpSignKeyRefreshEnabled": $NETBIRD_MGMT_IDP_SIGNKEY_REFRESH, "OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT" }, "IdpManagerConfig": { @@ -46,18 +51,25 @@ "ClientSecret": "$NETBIRD_IDP_MGMT_CLIENT_SECRET", "GrantType": "client_credentials" }, - "ExtraConfig": $NETBIRD_IDP_MGMT_EXTRA_CONFIG + "ExtraConfig": $NETBIRD_IDP_MGMT_EXTRA_CONFIG, + "Auth0ClientCredentials": null, + "AzureClientCredentials": null, + "KeycloakClientCredentials": null, + "ZitadelClientCredentials": null }, "DeviceAuthorizationFlow": { "Provider": "$NETBIRD_AUTH_DEVICE_AUTH_PROVIDER", "ProviderConfig": { "Audience": "$NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE", + "AuthorizationEndpoint": "", "Domain": "$NETBIRD_AUTH0_DOMAIN", "ClientID": "$NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID", + "ClientSecret": "", "TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT", "DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT", "Scope": "$NETBIRD_AUTH_DEVICE_AUTH_SCOPE", - "UseIDToken": $NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN + "UseIDToken": $NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN, + "RedirectURLs": null } }, "PKCEAuthorizationFlow": { @@ -65,6 +77,7 @@ "Audience": "$NETBIRD_AUTH_PKCE_AUDIENCE", "ClientID": "$NETBIRD_AUTH_CLIENT_ID", "ClientSecret": "$NETBIRD_AUTH_CLIENT_SECRET", + "Domain": "", "AuthorizationEndpoint": "$NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT", "TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT", "Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES", diff --git a/infrastructure_files/setup.env.example b/infrastructure_files/setup.env.example index f9ad63846..00c0c07f9 100644 --- a/infrastructure_files/setup.env.example +++ b/infrastructure_files/setup.env.example @@ -53,6 +53,8 @@ NETBIRD_MGMT_IDP="none" # Some IDPs requires different client id and client secret for management api NETBIRD_IDP_MGMT_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID NETBIRD_IDP_MGMT_CLIENT_SECRET="" +# With some IDPs may be needed enabling automatic refresh of signing keys on expire +# NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=false # NETBIRD_IDP_MGMT_EXTRA_ variables. See https://docs.netbird.io/selfhosted/identity-providers for more information about your IDP of choice. # ------------------------------------------- # Letsencrypt diff --git a/infrastructure_files/tests/setup.env b/infrastructure_files/tests/setup.env index b0999eb51..f02ef3d14 100644 --- a/infrastructure_files/tests/setup.env +++ b/infrastructure_files/tests/setup.env @@ -22,4 +22,6 @@ NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email" NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET -NETBIRD_SIGNAL_PORT=12345 \ No newline at end of file +NETBIRD_SIGNAL_PORT=12345 +NETBIRD_STORE_CONFIG_ENGINE=$CI_NETBIRD_STORE_CONFIG_ENGINE +NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH \ No newline at end of file diff --git a/infrastructure_files/turnserver.conf.tmpl b/infrastructure_files/turnserver.conf.tmpl index e5d3b231d..9b31cb511 100644 --- a/infrastructure_files/turnserver.conf.tmpl +++ b/infrastructure_files/turnserver.conf.tmpl @@ -696,7 +696,7 @@ no-cli #web-admin-port=8080 # Web-admin server listen on STUN/TURN worker threads -# By default it is disabled for security resons! (Not recommended in any production environment!) +# By default it is disabled for security reasons! (Not recommended in any production environment!) # #web-admin-listen-on-workers diff --git a/management/client/client_test.go b/management/client/client_test.go index 86c598adb..889b7a131 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -16,7 +16,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/encryption" - "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto" mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" @@ -53,7 +52,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, err := mgmt.NewFileStore(config.Datadir, nil) + store, err := mgmt.NewStoreFromJson(config.Datadir, nil) if err != nil { t.Fatal(err) } @@ -95,8 +94,8 @@ func startMockManagement(t *testing.T) (*grpc.Server, net.Listener, *mock_server } mgmtMockServer := &mock_server.ManagementServiceServerMock{ - GetServerKeyFunc: func(context.Context, *proto.Empty) (*proto.ServerKeyResponse, error) { - response := &proto.ServerKeyResponse{ + GetServerKeyFunc: func(context.Context, *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) { + response := &mgmtProto.ServerKeyResponse{ Key: serverKey.PublicKey().String(), } return response, nil @@ -300,19 +299,19 @@ func Test_SystemMetaDataFromClient(t *testing.T) { log.Fatalf("error while getting server public key from testclient, %v", err) } - var actualMeta *proto.PeerSystemMeta + var actualMeta *mgmtProto.PeerSystemMeta var actualValidKey string var wg sync.WaitGroup wg.Add(1) - mgmtMockServer.LoginFunc = func(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + mgmtMockServer.LoginFunc = func(ctx context.Context, msg *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey()) if err != nil { log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey) return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey) } - loginReq := &proto.LoginRequest{} + loginReq := &mgmtProto.LoginRequest{} err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq) if err != nil { log.Fatal(err) @@ -322,7 +321,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) { actualValidKey = loginReq.GetSetupKey() wg.Done() - loginResp := &proto.LoginResponse{} + loginResp := &mgmtProto.LoginResponse{} encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp) if err != nil { return nil, err @@ -343,7 +342,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) { wg.Wait() - expectedMeta := &proto.PeerSystemMeta{ + expectedMeta := &mgmtProto.PeerSystemMeta{ Hostname: info.Hostname, GoOS: info.GoOS, Kernel: info.Kernel, @@ -374,12 +373,12 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) { log.Fatalf("error while creating testClient: %v", err) } - expectedFlowInfo := &proto.DeviceAuthorizationFlow{ + expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{ Provider: 0, - ProviderConfig: &proto.ProviderConfig{ClientID: "client"}, + ProviderConfig: &mgmtProto.ProviderConfig{ClientID: "client"}, } - mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) { + mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) if err != nil { return nil, err @@ -418,14 +417,14 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { log.Fatalf("error while creating testClient: %v", err) } - expectedFlowInfo := &proto.PKCEAuthorizationFlow{ - ProviderConfig: &proto.ProviderConfig{ + expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{ + ProviderConfig: &mgmtProto.ProviderConfig{ ClientID: "client", ClientSecret: "secret", }, } - mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) { + mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) if err != nil { return nil, err diff --git a/management/client/grpc.go b/management/client/grpc.go index e4caed4b0..ddb420ee2 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE transportOption, grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 15 * time.Second, + Time: 30 * time.Second, Timeout: 10 * time.Second, })) if err != nil { diff --git a/management/cmd/management.go b/management/cmd/management.go index f85cf225e..1a00a0f57 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -101,7 +101,7 @@ var ( _, valid := dns.IsDomainName(dnsDomain) if !valid || len(dnsDomain) > 192 { - return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Lenght: %d", valid, len(dnsDomain)) + return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain)) } return nil @@ -126,7 +126,7 @@ var ( if err != nil { return err } - store, err := server.NewFileStore(config.Datadir, appMetrics) + store, err := server.NewStore(config.StoreConfig.Engine, config.Datadir, appMetrics) if err != nil { return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) } diff --git a/management/cmd/migration_down.go b/management/cmd/migration_down.go new file mode 100644 index 000000000..6d136ec1a --- /dev/null +++ b/management/cmd/migration_down.go @@ -0,0 +1,66 @@ +package cmd + +import ( + "errors" + "flag" + "fmt" + "os" + "path" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/util" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +var shortDown = "Rollback SQLite store to JSON file store. Please make a backup of the SQLite file before running this command." + +var downCmd = &cobra.Command{ + Use: "downgrade [--datadir directory] [--log-file console]", + Aliases: []string{"down"}, + Short: shortDown, + Long: shortDown + + "\n\n" + + "This command reads the content of {datadir}/store.db and migrates it to {datadir}/store.json that can be used by File store driver.", + RunE: func(cmd *cobra.Command, args []string) error { + flag.Parse() + err := util.InitLog(logLevel, logFile) + if err != nil { + return fmt.Errorf("failed initializing log %v", err) + } + + sqliteStorePath := path.Join(mgmtDataDir, "store.db") + if _, err := os.Stat(sqliteStorePath); errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("%s doesn't exist, couldn't continue the operation", sqliteStorePath) + } + + fileStorePath := path.Join(mgmtDataDir, "store.json") + if _, err := os.Stat(fileStorePath); err == nil { + return fmt.Errorf("%s already exists, couldn't continue the operation", fileStorePath) + } + + sqlstore, err := server.NewSqliteStore(mgmtDataDir, nil) + if err != nil { + return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err) + } + + sqliteStoreAccounts := len(sqlstore.GetAllAccounts()) + log.Infof("%d account will be migrated from sqlite store %s to file store %s", + sqliteStoreAccounts, sqliteStorePath, fileStorePath) + + store, err := server.NewFilestoreFromSqliteStore(sqlstore, mgmtDataDir, nil) + if err != nil { + return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err) + } + + fsStoreAccounts := len(store.GetAllAccounts()) + if fsStoreAccounts != sqliteStoreAccounts { + return fmt.Errorf("failed to migrate accounts from sqlite to file[]. Expected accounts: %d, got: %d", + sqliteStoreAccounts, fsStoreAccounts) + } + + log.Info("Migration finished successfully") + + return nil + }, +} diff --git a/management/cmd/migration_up.go b/management/cmd/migration_up.go new file mode 100644 index 000000000..5c7505cfc --- /dev/null +++ b/management/cmd/migration_up.go @@ -0,0 +1,66 @@ +package cmd + +import ( + "errors" + "flag" + "fmt" + "os" + "path" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/util" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command." + +var upCmd = &cobra.Command{ + Use: "upgrade [--datadir directory] [--log-file console]", + Aliases: []string{"up"}, + Short: shortUp, + Long: shortUp + + "\n\n" + + "This command reads the content of {datadir}/store.json and migrates it to {datadir}/store.db that can be used by SQLite store driver.", + RunE: func(cmd *cobra.Command, args []string) error { + flag.Parse() + err := util.InitLog(logLevel, logFile) + if err != nil { + return fmt.Errorf("failed initializing log %v", err) + } + + fileStorePath := path.Join(mgmtDataDir, "store.json") + if _, err := os.Stat(fileStorePath); errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("%s doesn't exist, couldn't continue the operation", fileStorePath) + } + + sqlStorePath := path.Join(mgmtDataDir, "store.db") + if _, err := os.Stat(sqlStorePath); err == nil { + return fmt.Errorf("%s already exists, couldn't continue the operation", sqlStorePath) + } + + fstore, err := server.NewFileStore(mgmtDataDir, nil) + if err != nil { + return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err) + } + + fsStoreAccounts := len(fstore.GetAllAccounts()) + log.Infof("%d account will be migrated from file store %s to sqlite store %s", + fsStoreAccounts, fileStorePath, sqlStorePath) + + store, err := server.NewSqliteStoreFromFileStore(fstore, mgmtDataDir, nil) + if err != nil { + return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err) + } + + sqliteStoreAccounts := len(store.GetAllAccounts()) + if fsStoreAccounts != sqliteStoreAccounts { + return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d", + fsStoreAccounts, sqliteStoreAccounts) + } + + log.Info("Migration finished successfully") + + return nil + }, +} diff --git a/management/cmd/root.go b/management/cmd/root.go index 2080a6b29..de8b5b8b3 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -34,6 +34,12 @@ var ( SilenceUsage: true, } + migrationCmd = &cobra.Command{ + Use: "sqlite-migration", + Short: "Contains sub-commands to perform JSON file store to SQLite store migration and rollback", + Long: "", + SilenceUsage: true, + } // Execution control channel for stopCh signal stopCh chan int ) @@ -55,7 +61,7 @@ func init() { mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") - mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) + mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max length is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account") rootCmd.MarkFlagRequired("config") //nolint @@ -63,6 +69,14 @@ func init() { rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout") rootCmd.AddCommand(mgmtCmd) + + migrationCmd.PersistentFlags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") + migrationCmd.MarkFlagRequired("datadir") //nolint + + migrationCmd.AddCommand(upCmd) + migrationCmd.AddCommand(downCmd) + + rootCmd.AddCommand(migrationCmd) } // SetupCloseHandler handles SIGTERM signal and exits with success diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index eb80f9299..45ef49e1f 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.21.12 +// protoc v3.21.9 // source: management.proto package proto @@ -1999,9 +1999,10 @@ type NameServerGroup struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - NameServers []*NameServer `protobuf:"bytes,1,rep,name=NameServers,proto3" json:"NameServers,omitempty"` - Primary bool `protobuf:"varint,2,opt,name=Primary,proto3" json:"Primary,omitempty"` - Domains []string `protobuf:"bytes,3,rep,name=Domains,proto3" json:"Domains,omitempty"` + NameServers []*NameServer `protobuf:"bytes,1,rep,name=NameServers,proto3" json:"NameServers,omitempty"` + Primary bool `protobuf:"varint,2,opt,name=Primary,proto3" json:"Primary,omitempty"` + Domains []string `protobuf:"bytes,3,rep,name=Domains,proto3" json:"Domains,omitempty"` + SearchDomainsEnabled bool `protobuf:"varint,4,opt,name=SearchDomainsEnabled,proto3" json:"SearchDomainsEnabled,omitempty"` } func (x *NameServerGroup) Reset() { @@ -2057,6 +2058,13 @@ func (x *NameServerGroup) GetDomains() []string { return nil } +func (x *NameServerGroup) GetSearchDomainsEnabled() bool { + if x != nil { + return x.SearchDomainsEnabled + } + return false +} + // NameServer represents a dns.NameServer type NameServer struct { state protoimpl.MessageState @@ -2444,73 +2452,76 @@ var file_management_proto_rawDesc = []byte{ 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, - 0x7f, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, - 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, - 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, - 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, - 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, - 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, - 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xf0, 0x02, 0x0a, 0x0c, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, - 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, - 0x72, 0x49, 0x50, 0x12, 0x40, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x37, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, + 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, + 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, + 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, + 0xf0, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x40, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x37, 0x0a, 0x06, 0x41, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, + 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x3d, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, - 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x3d, - 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, - 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, - 0x74, 0x22, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, - 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x22, - 0x1e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, - 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x22, - 0x3c, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, - 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, - 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, - 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x32, 0xd1, 0x03, - 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, - 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, - 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, - 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, - 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, - 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, + 0x55, 0x54, 0x10, 0x01, 0x22, 0x1e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, + 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, + 0x4f, 0x50, 0x10, 0x01, 0x22, 0x3c, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, + 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, + 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, + 0x10, 0x04, 0x32, 0xd1, 0x03, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, + 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, + 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, - 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, + 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, + 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, + 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, + 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/management/proto/management.proto b/management/proto/management.proto index d5b925d73..ae90beaf3 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -317,6 +317,7 @@ message NameServerGroup { repeated NameServer NameServers = 1; bool Primary = 2; repeated string Domains = 3; + bool SearchDomainsEnabled = 4; } // NameServer represents a dns.NameServer diff --git a/management/server/account.go b/management/server/account.go index ab79a6789..30a9bd200 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -36,6 +36,7 @@ const ( UnknownCategory = "unknown" GroupIssuedAPI = "api" GroupIssuedJWT = "jwt" + GroupIssuedIntegration = "integration" CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour @@ -91,7 +92,7 @@ type AccountManager interface { DeleteRoute(accountID, routeID, userID string) error ListRoutes(accountID, userID string) ([]*route.Route, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroup(accountID, nsGroupID, userID string) error ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) @@ -103,6 +104,7 @@ type AccountManager interface { UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API + GetAllConnectedPeers() (map[string]struct{}, error) } type DefaultAccountManager struct { @@ -164,43 +166,54 @@ func (s *Settings) Copy() *Settings { // Account represents a unique account of the system type Account struct { - Id string + // we have to name column to aid as it collides with Network.Id when work with associations + Id string `gorm:"primaryKey"` + // User.Id it was created by CreatedBy string - Domain string + Domain string `gorm:"index"` DomainCategory string IsDomainPrimaryAccount bool - SetupKeys map[string]*SetupKey - Network *Network - Peers map[string]*Peer - Users map[string]*User - Groups map[string]*Group - Rules map[string]*Rule - Policies []*Policy - Routes map[string]*route.Route - NameServerGroups map[string]*nbdns.NameServerGroup - DNSSettings *DNSSettings + SetupKeys map[string]*SetupKey `gorm:"-"` + SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` + Network *Network `gorm:"embedded;embeddedPrefix:network_"` + Peers map[string]*Peer `gorm:"-"` + PeersG []Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` + Users map[string]*User `gorm:"-"` + UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*Group `gorm:"-"` + GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Rules map[string]*Rule `gorm:"-"` + RulesG []Rule `json:"-" gorm:"foreignKey:AccountID;references:id"` + Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` + Routes map[string]*route.Route `gorm:"-"` + RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` + NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` + NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` // Settings is a dictionary of Account settings - Settings *Settings + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } type UserInfo struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Role string `json:"role"` - AutoGroups []string `json:"auto_groups"` - Status string `json:"-"` - IsServiceUser bool `json:"is_service_user"` - IsBlocked bool `json:"is_blocked"` - LastLogin time.Time `json:"last_login"` + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` + IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` + LastLogin time.Time `json:"last_login"` + Issued string `json:"issued"` + IntegrationReference IntegrationReference `json:"-"` } // getRoutesToSync returns the enabled routes for the peer ID and the routes // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. func (a *Account) getRoutesToSync(peerID string, aclPeers []*Peer) []*route.Route { - routes, peerDisabledRoutes := a.getEnabledAndDisabledRoutesByPeer(peerID) + routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID) peerRoutesMembership := make(lookupMap) for _, r := range append(routes, peerDisabledRoutes...) { peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{} @@ -208,7 +221,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*Peer) []*route.Rout groupListMap := a.getPeerGroups(peerID) for _, peer := range aclPeers { - activeRoutes, _ := a.getEnabledAndDisabledRoutesByPeer(peer.ID) + activeRoutes, _ := a.getRoutingPeerRoutes(peer.ID) groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) routes = append(routes, filteredRoutes...) @@ -244,20 +257,32 @@ func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap looku return filteredRoutes } -// getEnabledAndDisabledRoutesByPeer returns the enabled and disabled lists of routes that belong to a peer. +// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Route, []*route.Route) { - var enabledRoutes []*route.Route - var disabledRoutes []*route.Route +// If the given is not a routing peer, then the lists are empty. +func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + + peer := a.GetPeer(peerID) + if peer == nil { + log.Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) + return enabledRoutes, disabledRoutes + } + + // currently we support only linux routing peers + if peer.Meta.GoOS != "linux" { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[string]struct{}) takeRoute := func(r *route.Route, id string) { - peer := a.GetPeer(peerID) - if peer == nil { - log.Errorf("route %s has peer %s that doesn't exist under account %s", r.ID, peerID, a.Id) + if _, ok := seenRoute[r.ID]; ok { return } + seenRoute[r.ID] = struct{}{} if r.Enabled { + r.Peer = peer.Key enabledRoutes = append(enabledRoutes, r) return } @@ -265,25 +290,30 @@ func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Rou } for _, r := range a.Routes { - if len(r.PeerGroups) != 0 { - for _, groupID := range r.PeerGroups { - group := a.GetGroup(groupID) - if group == nil { - log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) + for _, groupID := range r.PeerGroups { + group := a.GetGroup(groupID) + if group == nil { + log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) + continue + } + for _, id := range group.Peers { + if id != peerID { continue } - for _, id := range group.Peers { - if id == peerID { - takeRoute(r, id) - break - } - } + + newPeerRoute := r.Copy() + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map + takeRoute(newPeerRoute, id) + break } } if r.Peer == peerID { - takeRoute(r, peerID) + takeRoute(r.Copy(), peerID) } } + return enabledRoutes, disabledRoutes } @@ -319,50 +349,7 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { peersToConnect = append(peersToConnect, p) } - routes := a.getRoutesToSync(peerID, peersToConnect) - - takePeer := func(id string) (*Peer, bool) { - peer := a.GetPeer(id) - if peer == nil || peer.Meta.GoOS != "linux" { - return nil, false - } - return peer, true - } - - // We need to set Peer.Key instead of Peer.ID because this object will be sent to agents as part of a network map. - // Ideally we should have a separate field for that, but fine for now. - var routesUpdate []*route.Route - seenPeers := make(map[string]bool) - for _, r := range routes { - if r.Peer != "" { - peer, valid := takePeer(r.Peer) - if !valid { - continue - } - rCopy := r.Copy() - rCopy.Peer = peer.Key // client expects the key - routesUpdate = append(routesUpdate, rCopy) - continue - } - for _, groupID := range r.PeerGroups { - if group := a.GetGroup(groupID); group != nil { - for _, peerId := range group.Peers { - peer, valid := takePeer(peerId) - if !valid { - continue - } - - if _, ok := seenPeers[peer.ID]; !ok { - rCopy := r.Copy() - rCopy.ID = r.ID + ":" + peer.ID // we have to provide unit route id when distribute network map - rCopy.Peer = peer.Key // client expects the key - routesUpdate = append(routesUpdate, rCopy) - } - seenPeers[peer.ID] = true - } - } - } - } + routesUpdate := a.getRoutesToSync(peerID, peersToConnect) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) dnsUpdate := nbdns.Config{ @@ -538,13 +525,11 @@ func (a *Account) getUserGroups(userID string) ([]string, error) { func (a *Account) getPeerDNSManagementStatus(peerID string) bool { peerGroups := a.getPeerGroups(peerID) enabled := true - if a.DNSSettings != nil { - for _, groupID := range a.DNSSettings.DisabledManagementGroups { - _, found := peerGroups[groupID] - if found { - enabled = false - break - } + for _, groupID := range a.DNSSettings.DisabledManagementGroups { + _, found := peerGroups[groupID] + if found { + enabled = false + break } } return enabled @@ -631,10 +616,7 @@ func (a *Account) Copy() *Account { nsGroups[id] = nsGroup.Copy() } - var dnsSettings *DNSSettings - if a.DNSSettings != nil { - dnsSettings = a.DNSSettings.Copy() - } + dnsSettings := a.DNSSettings.Copy() var settings *Settings if a.Settings != nil { @@ -972,6 +954,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error { if err != nil { return err } + log.Infof("%d entries received from IdP management", len(userData)) // If the Identity Provider does not support writing AppMetadata, // in cases like this, we expect it to return all users in an "unset" field. @@ -1071,6 +1054,7 @@ func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interf if err != nil { return nil, err } + log.Debugf("%d entries received from IdP management", len(userData)) dataMap := make(map[string]*idp.UserData, len(userData)) for _, datum := range userData { @@ -1582,6 +1566,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } } +// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() +func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { + return am.peersUpdateManager.GetAllConnectedPeers(), nil +} + func isDomainValid(domain string) bool { re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) return re.Match([]byte(domain)) @@ -1636,7 +1625,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { setupKeys := map[string]*SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) users[userID] = NewAdminUser(userID) - dnsSettings := &DNSSettings{ + dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } log.Debugf("created new account %s", accountID) diff --git a/management/server/account_test.go b/management/server/account_test.go index d55734685..c8a8a5dc9 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -198,11 +198,11 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { netIP := net.IP{100, 64, 0, 0} netMask := net.IPMask{255, 255, 0, 0} network := &Network{ - Id: "network", - Net: net.IPNet{IP: netIP, Mask: netMask}, - Dns: "netbird.selfhosted", - Serial: 0, - mu: sync.Mutex{}, + Identifier: "network", + Net: net.IPNet{IP: netIP, Mask: netMask}, + Dns: "netbird.selfhosted", + Serial: 0, + mu: sync.Mutex{}, } for _, testCase := range tt { @@ -476,7 +476,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { // as initAccount was created without account id we have to take the id after account initialization // that happens inside the GetAccountByUserOrAccountID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount.Id = acc.Id + initAccount = acc claims := jwtclaims.AuthorizationClaims{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount @@ -1025,7 +1025,6 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { wg.Wait() }) - t.Run("delete peer update", func(t *testing.T) { wg.Add(1) go func() { @@ -1117,7 +1116,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { } if account.Network.CurrentSerial() != 2 { - t.Errorf("expecting Network Serial=%d to be incremented and be equal to 2 after adding and deleteing a peer", account.Network.CurrentSerial()) + t.Errorf("expecting Network Serial=%d to be incremented and be equal to 2 after adding and deleting a peer", account.Network.CurrentSerial()) } ev := getEvent(t, account.Id, manager, activity.PeerRemovedByUser) @@ -1237,7 +1236,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { } account := &Account{ Peers: map[string]*Peer{ - "peer-1": {Key: "peer-1"}, "peer-2": {Key: "peer-2"}, "peer-3": {Key: "peer-1"}, + "peer-1": {Key: "peer-1", Meta: PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: PeerSystemMeta{GoOS: "linux"}}, }, Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[string]*route.Route{ @@ -1309,7 +1308,7 @@ func TestAccount_Copy(t *testing.T) { }, }, Network: &Network{ - Id: "net1", + Identifier: "net1", }, Peers: map[string]*Peer{ "peer1": { @@ -1374,7 +1373,7 @@ func TestAccount_Copy(t *testing.T) { NameServers: []nbdns.NameServer{}, }, }, - DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}}, + DNSSettings: DNSSettings{DisabledManagementGroups: []string{}}, Settings: &Settings{}, } err := hasNilField(account) @@ -1400,6 +1399,10 @@ func hasNilField(x interface{}) error { rv := reflect.ValueOf(x) rv = rv.Elem() for i := 0; i < rv.NumField(); i++ { + // skip gorm internal fields + if json, ok := rv.Type().Field(i).Tag.Lookup("json"); ok && json == "-" { + continue + } if f := rv.Field(i); f.IsValid() { k := f.Kind() switch k { @@ -2045,7 +2048,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { func createStore(t *testing.T) (Store, error) { dataDir := t.TempDir() - store, err := NewFileStore(dataDir, nil) + store, err := NewStoreFromJson(dataDir, nil) if err != nil { return nil, err } diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index 6af4d4d8d..a5130b0c5 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -45,6 +45,9 @@ const ( "VALUES(?, ?, ?, ?, ?, ?)" insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)` + + fallbackName = "unknown" + fallbackEmail = "unknown@unknown.com" ) // Store is the implementation of the activity.Store interface backed by SQLite @@ -128,6 +131,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { events := make([]*activity.Event, 0) + var cryptErr error for result.Next() { var id int64 var operation activity.Activity @@ -156,8 +160,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { if targetUserName != nil { name, err := store.fieldEncrypt.Decrypt(*targetUserName) if err != nil { - log.Errorf("failed to decrypt username for target id: %s", target) - meta["username"] = "" + cryptErr = fmt.Errorf("failed to decrypt username for target id: %s", target) + meta["username"] = fallbackName } else { meta["username"] = name } @@ -166,8 +170,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { if targetEmail != nil { email, err := store.fieldEncrypt.Decrypt(*targetEmail) if err != nil { - log.Errorf("failed to decrypt email address for target id: %s", target) - meta["email"] = "" + cryptErr = fmt.Errorf("failed to decrypt email address for target id: %s", target) + meta["email"] = fallbackEmail } else { meta["email"] = email } @@ -186,7 +190,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { if initiatorName != nil { name, err := store.fieldEncrypt.Decrypt(*initiatorName) if err != nil { - log.Errorf("failed to decrypt username of initiator: %s", initiator) + cryptErr = fmt.Errorf("failed to decrypt username of initiator: %s", initiator) + event.InitiatorName = fallbackName } else { event.InitiatorName = name } @@ -195,7 +200,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { if initiatorEmail != nil { email, err := store.fieldEncrypt.Decrypt(*initiatorEmail) if err != nil { - log.Errorf("failed to decrypt email address of initiator: %s", initiator) + cryptErr = fmt.Errorf("failed to decrypt email address of initiator: %s", initiator) + event.InitiatorEmail = fallbackEmail } else { event.InitiatorEmail = email } @@ -204,6 +210,10 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { events = append(events, event) } + if cryptErr != nil { + log.Warnf("%s", cryptErr) + } + return events, nil } diff --git a/management/server/config.go b/management/server/config.go index 31c1cf45c..4fed93bba 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -45,6 +45,8 @@ type Config struct { DeviceAuthorizationFlow *DeviceAuthorizationFlow PKCEAuthorizationFlow *PKCEAuthorizationFlow + + StoreConfig StoreConfig } // GetAuthAudiences returns the audience from the http config and device authorization flow config @@ -136,6 +138,11 @@ type ProviderConfig struct { RedirectURLs []string } +// StoreConfig contains Store configuration +type StoreConfig struct { + Engine StoreEngine +} + // validateURL validates input http url func validateURL(httpURL string) bool { _, err := url.ParseRequestURI(httpURL) diff --git a/management/server/dns.go b/management/server/dns.go index 252782aea..f90a5e9f2 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -20,23 +20,15 @@ type lookupMap map[string]struct{} // DNSSettings defines dns settings at the account level type DNSSettings struct { // DisabledManagementGroups groups whose DNS management is disabled - DisabledManagementGroups []string + DisabledManagementGroups []string `gorm:"serializer:json"` } // Copy returns a copy of the DNS settings -func (d *DNSSettings) Copy() *DNSSettings { - settings := &DNSSettings{ - DisabledManagementGroups: make([]string, 0), +func (d DNSSettings) Copy() DNSSettings { + settings := DNSSettings{ + DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)), } - - if d == nil { - return settings - } - - if d.DisabledManagementGroups != nil && len(d.DisabledManagementGroups) > 0 { - settings.DisabledManagementGroups = d.DisabledManagementGroups[:] - } - + copy(settings.DisabledManagementGroups, d.DisabledManagementGroups) return settings } @@ -58,12 +50,8 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) if !user.IsAdmin() { return nil, status.Errorf(status.PermissionDenied, "only admins are allowed to view DNS settings") } - - if account.DNSSettings == nil { - return &DNSSettings{}, nil - } - - return account.DNSSettings.Copy(), nil + dnsSettings := account.DNSSettings.Copy() + return &dnsSettings, nil } // SaveDNSSettings validates a user role and updates the account's DNS settings @@ -96,11 +84,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string } } - oldSettings := &DNSSettings{} - if account.DNSSettings != nil { - oldSettings = account.DNSSettings.Copy() - } - + oldSettings := account.DNSSettings.Copy() account.DNSSettings = dnsSettingsToSave.Copy() account.Network.IncSerial() @@ -146,8 +130,9 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig { for _, nsGroup := range update.NameServerGroups { protoGroup := &proto.NameServerGroup{ - Primary: nsGroup.Primary, - Domains: nsGroup.Domains, + Primary: nsGroup.Primary, + Domains: nsGroup.Domains, + SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, } for _, ns := range nsGroup.NameServers { protoNS := &proto.NameServer{ @@ -231,7 +216,7 @@ func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) { log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels) if err != nil { - log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skiping", peer.Meta.Hostname, err) + log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) continue } } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index b089949b2..a2c9d3aa2 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -42,7 +42,7 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("DNS settings for new accounts shouldn't return nil") } - account.DNSSettings = &DNSSettings{ + account.DNSSettings = DNSSettings{ DisabledManagementGroups: []string{group1ID}, } @@ -196,7 +196,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSStore(t *testing.T) (Store, error) { dataDir := t.TempDir() - store, err := NewFileStore(dataDir, nil) + store, err := NewStoreFromJson(dataDir, nil) if err != nil { return nil, err } diff --git a/management/server/file_store.go b/management/server/file_store.go index ecd02ba99..73c52927e 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -54,6 +54,25 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err return fs, nil } +// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir +func NewFilestoreFromSqliteStore(sqlitestore *SqliteStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { + store, err := NewFileStore(dataDir, metrics) + if err != nil { + return nil, err + } + + err = store.SaveInstallationID(sqlitestore.GetInstallationID()) + if err != nil { + return nil, err + } + + for _, account := range sqlitestore.GetAllAccounts() { + store.Accounts[account.Id] = account + } + + return store, store.persist(store.storeFile) +} + // restore the state of the store from the file. // Creates a new empty store file if doesn't exist func restore(file string) (*FileStore, error) { @@ -111,13 +130,14 @@ func restore(file string) (*FileStore, error) { for _, peer := range account.Peers { store.PeerKeyID2AccountID[peer.Key] = accountID store.PeerID2AccountID[peer.ID] = accountID - // reset all peers to status = Disconnected - if peer.Status != nil && peer.Status.Connected { - peer.Status.Connected = false - } } for _, user := range account.Users { store.UserID2AccountID[user.Id] = accountID + if user.Issued == "" { + user.Issued = UserIssuedAPI + account.Users[user.Id] = user + } + for _, pat := range user.PATs { store.TokenID2UserID[pat.ID] = user.Id store.HashedPAT2TokenID[pat.HashedToken] = pat.ID @@ -599,3 +619,8 @@ func (s *FileStore) Close() error { return s.persist(s.storeFile) } + +// GetStoreEngine returns FileStoreEngine +func (s *FileStore) GetStoreEngine() StoreEngine { + return FileStoreEngine +} diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index e2f07acda..705e9f149 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -387,7 +387,7 @@ func TestFileStore_GetAccount(t *testing.T) { assert.Equal(t, expected.DomainCategory, account.DomainCategory) assert.Equal(t, expected.Domain, account.Domain) assert.Equal(t, expected.CreatedBy, account.CreatedBy) - assert.Equal(t, expected.Network.Id, account.Network.Id) + assert.Equal(t, expected.Network.Identifier, account.Network.Identifier) assert.Len(t, account.Peers, len(expected.Peers)) assert.Len(t, account.Users, len(expected.Users)) assert.Len(t, account.SetupKeys, len(expected.SetupKeys)) diff --git a/management/server/group.go b/management/server/group.go index a7502134a..d626c3538 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -23,6 +23,9 @@ type Group struct { // ID of the group ID string + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + // Name visible in the UI Name string @@ -30,7 +33,9 @@ type Group struct { Issued string // Peers list of the group - Peers []string + Peers []string `gorm:"serializer:json"` + + IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } // EventMeta returns activity event meta related to the group @@ -40,10 +45,11 @@ func (g *Group) EventMeta() map[string]any { func (g *Group) Copy() *Group { group := &Group{ - ID: g.ID, - Name: g.Name, - Issued: g.Issued, - Peers: make([]string, len(g.Peers)), + ID: g.ID, + Name: g.Name, + Issued: g.Issued, + Peers: make([]string, len(g.Peers)), + IntegrationReference: g.IntegrationReference, } copy(group.Peers, g.Peers) return group @@ -157,6 +163,11 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) return nil } + // check integration link + if g.Issued == GroupIssuedIntegration { + return &GroupLinkError{GroupIssuedIntegration, g.IntegrationReference.String()} + } + // check route links for _, r := range account.Routes { for _, g := range r.Groups { diff --git a/management/server/group_test.go b/management/server/group_test.go index 3e2d6d3cc..5db0ca900 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -52,6 +52,11 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) { "grp-for-users", "user", }, + { + "integration", + "grp-for-integration", + "integration", + }, } for _, testCase := range testCases { @@ -79,38 +84,51 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { domain := "example.com" groupForRoute := &Group{ - "grp-for-route", - "Group for route", - GroupIssuedAPI, - make([]string, 0), + ID: "grp-for-route", + AccountID: "account-id", + Name: "Group for route", + Issued: GroupIssuedAPI, + Peers: make([]string, 0), } groupForNameServerGroups := &Group{ - "grp-for-name-server-grp", - "Group for name server groups", - GroupIssuedAPI, - make([]string, 0), + ID: "grp-for-name-server-grp", + AccountID: "account-id", + Name: "Group for name server groups", + Issued: GroupIssuedAPI, + Peers: make([]string, 0), } groupForPolicies := &Group{ - "grp-for-policies", - "Group for policies", - GroupIssuedAPI, - make([]string, 0), + ID: "grp-for-policies", + AccountID: "account-id", + Name: "Group for policies", + Issued: GroupIssuedAPI, + Peers: make([]string, 0), } groupForSetupKeys := &Group{ - "grp-for-keys", - "Group for setup keys", - GroupIssuedAPI, - make([]string, 0), + ID: "grp-for-keys", + AccountID: "account-id", + Name: "Group for setup keys", + Issued: GroupIssuedAPI, + Peers: make([]string, 0), } groupForUsers := &Group{ - "grp-for-users", - "Group for users", - GroupIssuedAPI, - make([]string, 0), + ID: "grp-for-users", + AccountID: "account-id", + Name: "Group for users", + Issued: GroupIssuedAPI, + Peers: make([]string, 0), + } + + groupForIntegration := &Group{ + ID: "grp-for-integration", + AccountID: "account-id", + Name: "Group for users", + Issued: GroupIssuedIntegration, + Peers: make([]string, 0), } routeResource := &route.Route{ @@ -159,6 +177,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { _ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies) _ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys) _ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers) + _ = am.SaveGroup(accountID, groupAdminUserID, groupForIntegration) return am.Store.GetAccount(account.Id) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 383cb0d1f..f32f6347a 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -169,7 +169,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.cancelPeerRoutines(peer) return nil } - log.Debugf("recevied an update for peer %s", peerKey.String()) + log.Debugf("received an update for peer %s", peerKey.String()) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 3b257a703..08c98c830 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -117,7 +117,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedID: accountID, }, { - name: "PutAccount OK wiht JWT", + name: "PutAccount OK with JWT", expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, @@ -134,7 +134,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedID: accountID, }, { - name: "PutAccount OK wiht JWT Propagation", + name: "PutAccount OK with JWT Propagation", expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 658d389f6..a0a64fd98 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -125,6 +125,10 @@ components: description: Is true if this user is blocked. Blocked users can't use the system type: boolean example: false + issued: + description: How user was issued by API or Integration + type: string + example: api required: - id - email @@ -857,13 +861,17 @@ components: type: boolean example: true domains: - description: Nameserver group domain list + description: Nameserver group match domain list type: array items: type: string minLength: 1 maxLength: 255 example: "example.com" + search_domains_enabled: + description: Nameserver group search domain status for match domains. It should be true only if domains list is not empty. + type: boolean + example: true required: - name - description @@ -872,6 +880,7 @@ components: - groups - primary - domains + - search_domains_enabled NameserverGroup: allOf: - type: object diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index fd3eedde3..ddf8ce65f 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1,6 +1,6 @@ // Package api provides primitives to interact with the openapi HTTP API. // -// Code generated by github.com/deepmap/oapi-codegen version v1.15.0 DO NOT EDIT. +// Code generated by github.com/deepmap/oapi-codegen version v1.11.1-0.20220912230023-4a1477f6a8ba DO NOT EDIT. package api import ( @@ -248,7 +248,7 @@ type NameserverGroup struct { // Description Nameserver group description Description string `json:"description"` - // Domains Nameserver group domain list + // Domains Nameserver group match domain list Domains []string `json:"domains"` // Enabled Nameserver group status @@ -268,6 +268,9 @@ type NameserverGroup struct { // Primary Nameserver group primary status Primary bool `json:"primary"` + + // SearchDomainsEnabled Nameserver group search domain status for match domains. It should be true only if domains list is not empty. + SearchDomainsEnabled bool `json:"search_domains_enabled"` } // NameserverGroupRequest defines model for NameserverGroupRequest. @@ -275,7 +278,7 @@ type NameserverGroupRequest struct { // Description Nameserver group description Description string `json:"description"` - // Domains Nameserver group domain list + // Domains Nameserver group match domain list Domains []string `json:"domains"` // Enabled Nameserver group status @@ -292,6 +295,9 @@ type NameserverGroupRequest struct { // Primary Nameserver group primary status Primary bool `json:"primary"` + + // SearchDomainsEnabled Nameserver group search domain status for match domains. It should be true only if domains list is not empty. + SearchDomainsEnabled bool `json:"search_domains_enabled"` } // Peer defines model for Peer. @@ -785,6 +791,9 @@ type User struct { // IsServiceUser Is true if this user is a service user IsServiceUser *bool `json:"is_service_user,omitempty"` + // Issued How user was issued by API or Integration + Issued *string `json:"issued,omitempty"` + // LastLogin Last time this user performed a login to the dashboard LastLogin *time.Time `json:"last_login,omitempty"` diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go index c7d135fd1..a2f65a521 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/dns_settings_handler_test.go @@ -26,7 +26,7 @@ const ( testDNSSettingsUserID = "test_user" ) -var baseExistingDNSSettings = &server.DNSSettings{ +var baseExistingDNSSettings = server.DNSSettings{ DisabledManagementGroups: []string{testDNSSettingsExistingGroup}, } @@ -43,7 +43,7 @@ func initDNSSettingsTestData() *DNSSettingsHandler { return &DNSSettingsHandler{ accountManager: &mock_server.MockAccountManager{ GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) { - return testingDNSSettingsAccount.DNSSettings, nil + return &testingDNSSettingsAccount.DNSSettings, nil }, SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { if dnsSettingsToSave != nil { diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index d409623df..c58916250 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -107,10 +107,11 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { peers = *req.Peers } group := server.Group{ - ID: groupID, - Name: req.Name, - Peers: peers, - Issued: eg.Issued, + ID: groupID, + Name: req.Name, + Peers: peers, + Issued: eg.Issued, + IntegrationReference: eg.IntegrationReference, } if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 6e9b029c7..c589512e5 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -6,6 +6,7 @@ import ( "github.com/gorilla/mux" "github.com/rs/cors" + "github.com/netbirdio/management-integrations/integrations" s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -58,6 +59,12 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid AuthCfg: authCfg, } + claimsExtractor := jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ) + + integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor) api.addAccountsEndpoint() api.addPeersEndpoint() api.addUsersEndpoint() @@ -73,8 +80,8 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { methods, err := route.GetMethods() - if err != nil { - return err + if err != nil { // we may have wildcard routes from integrations without methods, skip them for now + methods = []string{} } for _, method := range methods { template, err := route.GetPathTemplate() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 710723124..99482bfb7 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -57,10 +57,17 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.Split(r.Header.Get("Authorization"), " ") - authType := auth[0] - switch strings.ToLower(authType) { + authType := strings.ToLower(auth[0]) + + // fallback to token when receive pat as bearer + if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") { + authType = "token" + auth[0] = authType + } + + switch authType { case "bearer": - err := m.CheckJWTFromRequest(w, r) + err := m.checkJWTFromRequest(w, r, auth) if err != nil { log.Errorf("Error when validating JWT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) @@ -68,7 +75,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } h.ServeHTTP(w, r) case "token": - err := m.CheckPATFromRequest(w, r) + err := m.checkPATFromRequest(w, r, auth) if err != nil { log.Debugf("Error when validating PAT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) @@ -83,9 +90,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error { - - token, err := getTokenFromJWTRequest(r) +func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { + token, err := getTokenFromJWTRequest(auth) // If an error occurs, call the error handler and return an error if err != nil { @@ -110,8 +116,8 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error { - token, err := getTokenFromPATRequest(r) +func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { + token, err := getTokenFromPATRequest(auth) // If an error occurs, call the error handler and return an error if err != nil { @@ -143,16 +149,9 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ return nil } -// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts +// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts // the JWT token from the Authorization header. -func getTokenFromJWTRequest(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", nil // No error, just no token - } - - // TODO: Make this a bit more robust, parsing-wise - authHeaderParts := strings.Fields(authHeader) +func getTokenFromJWTRequest(authHeaderParts []string) (string, error) { if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { return "", errors.New("Authorization header format must be Bearer {token}") } @@ -160,16 +159,9 @@ func getTokenFromJWTRequest(r *http.Request) (string, error) { return authHeaderParts[1], nil } -// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts +// getTokenFromPATRequest is a "TokenExtractor" that takes auth header parts and extracts // the PAT token from the Authorization header. -func getTokenFromPATRequest(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", nil // No error, just no token - } - - // TODO: Make this a bit more robust, parsing-wise - authHeaderParts := strings.Fields(authHeader) +func getTokenFromPATRequest(authHeaderParts []string) (string, error) { if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" { return "", errors.New("Authorization header format must be Token {token}") } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 608bf42fa..55e5de260 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -19,7 +19,7 @@ const ( domain = "domain" userID = "userID" tokenID = "tokenID" - PAT = "PAT" + PAT = "nbp_PAT" JWT = "JWT" wrongToken = "wrongToken" ) @@ -82,6 +82,11 @@ func TestAuthMiddleware_Handler(t *testing.T) { authHeader: "Token " + wrongToken, expectedStatusCode: 401, }, + { + name: "Fallback to PAT Token", + authHeader: "Bearer " + PAT, + expectedStatusCode: 200, + }, { name: "Valid JWT Token", authHeader: "Bearer " + JWT, diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index 918988d69..871bf639a 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt return } - nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id) + nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) if err != nil { util.WriteError(err, w) return @@ -119,14 +119,15 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt } updatedNSGroup := &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: req.Name, - Description: req.Description, - Primary: req.Primary, - Domains: req.Domains, - NameServers: nsList, - Groups: req.Groups, - Enabled: req.Enabled, + ID: nsGroupID, + Name: req.Name, + Description: req.Description, + Primary: req.Primary, + Domains: req.Domains, + NameServers: nsList, + Groups: req.Groups, + Enabled: req.Enabled, + SearchDomainsEnabled: req.SearchDomainsEnabled, } err = h.accountManager.SaveNameServerGroup(account.Id, user.Id, updatedNSGroup) @@ -216,13 +217,14 @@ func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.Namese } return &api.NameserverGroup{ - Id: serverNSGroup.ID, - Name: serverNSGroup.Name, - Description: serverNSGroup.Description, - Primary: serverNSGroup.Primary, - Domains: serverNSGroup.Domains, - Groups: serverNSGroup.Groups, - Nameservers: nsList, - Enabled: serverNSGroup.Enabled, + Id: serverNSGroup.ID, + Name: serverNSGroup.Name, + Description: serverNSGroup.Description, + Primary: serverNSGroup.Primary, + Domains: serverNSGroup.Domains, + Groups: serverNSGroup.Groups, + Nameservers: nsList, + Enabled: serverNSGroup.Enabled, + SearchDomainsEnabled: serverNSGroup.SearchDomainsEnabled, } } diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 100f4b87a..b00ff606f 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -67,16 +67,17 @@ func initNameserversTestData() *NameserversHandler { } return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) }, - CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string) (*nbdns.NameServerGroup, error) { + CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) { return &nbdns.NameServerGroup{ - ID: existingNSGroupID, - Name: name, - Description: description, - NameServers: nameServerList, - Groups: groups, - Enabled: enabled, - Primary: primary, - Domains: domains, + ID: existingNSGroupID, + Name: name, + Description: description, + NameServers: nameServerList, + Groups: groups, + Enabled: enabled, + Primary: primary, + Domains: domains, + SearchDomainsEnabled: searchDomains, }, nil }, DeleteNameServerGroupFunc: func(accountID, nsGroupID, _ string) error { diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index adf4a9721..a485d6ccf 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -31,6 +31,24 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee } } +func (h *PeersHandler) checkPeerStatus(peer *server.Peer) (*server.Peer, error) { + peerToReturn := peer.Copy() + if peer.Status.Connected { + statuses, err := h.accountManager.GetAllConnectedPeers() + if err != nil { + return peerToReturn, err + } + + // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected + // This may happen after server restart when not all peers are yet connected + if _, connected := statuses[peerToReturn.ID]; !connected { + peerToReturn.Status.Connected = false + } + } + + return peerToReturn, nil +} + func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(account.Id, peerID, userID) if err != nil { @@ -38,7 +56,13 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w return } - util.WriteJSONObject(w, toPeerResponse(peer, account, h.accountManager.GetDNSDomain())) + peerToReturn, err := h.checkPeerStatus(peer) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, toPeerResponse(peerToReturn, account, h.accountManager.GetDNSDomain())) } func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { @@ -120,7 +144,12 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody := []*api.Peer{} for _, peer := range peers { - respBody = append(respBody, toPeerResponse(peer, account, dnsDomain)) + peerToReturn, err := h.checkPeerStatus(peer) + if err != nil { + util.WriteError(err, w) + return + } + respBody = append(respBody, toPeerResponse(peerToReturn, account, dnsDomain)) } util.WriteJSONObject(w, respBody) return diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index 7fe732f2f..1856861d5 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -3,6 +3,7 @@ package http import ( "bytes" "encoding/json" + "fmt" "io" "net" "net/http" @@ -23,19 +24,33 @@ import ( ) const testPeerID = "test_peer" +const noUpdateChannelTestPeerID = "no-update-channel" func initTestMetaData(peers ...*server.Peer) *PeersHandler { return &PeersHandler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) { - p := peers[0].Copy() + var p *server.Peer + for _, peer := range peers { + if update.ID == peer.ID { + p = peer.Copy() + break + } + } p.SSHEnabled = update.SSHEnabled p.LoginExpirationEnabled = update.LoginExpirationEnabled p.Name = update.Name return p, nil }, GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) { - return peers[0], nil + var p *server.Peer + for _, peer := range peers { + if peerID == peer.ID { + p = peer.Copy() + break + } + } + return p, nil }, GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) { return peers, nil @@ -57,6 +72,16 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler { }, }, user, nil }, + GetAllConnectedPeersFunc: func() (map[string]struct{}, error) { + statuses := make(map[string]struct{}) + for _, peer := range peers { + if peer.ID == noUpdateChannelTestPeerID { + break + } + statuses[peer.ID] = struct{}{} + } + return statuses, nil + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { @@ -79,7 +104,7 @@ func TestGetPeers(t *testing.T) { Key: "key", SetupKey: "setupkey", IP: net.ParseIP("100.64.0.1"), - Status: &server.PeerStatus{}, + Status: &server.PeerStatus{Connected: true}, Name: "PeerName", LoginExpirationEnabled: false, Meta: server.PeerSystemMeta{ @@ -93,11 +118,17 @@ func TestGetPeers(t *testing.T) { }, } + peer1 := peer.Copy() + peer1.ID = noUpdateChannelTestPeerID + expectedUpdatedPeer := peer.Copy() expectedUpdatedPeer.LoginExpirationEnabled = true expectedUpdatedPeer.SSHEnabled = true expectedUpdatedPeer.Name = "New Name" + expectedPeer1 := peer1.Copy() + expectedPeer1.Status.Connected = false + tt := []struct { name string expectedStatus int @@ -116,13 +147,21 @@ func TestGetPeers(t *testing.T) { expectedPeer: peer, }, { - name: "GetPeer", + name: "GetPeer with update channel", requestType: http.MethodGet, requestPath: "/api/peers/" + testPeerID, expectedStatus: http.StatusOK, expectedArray: false, expectedPeer: peer, }, + { + name: "GetPeer with no update channel", + requestType: http.MethodGet, + requestPath: "/api/peers/" + peer1.ID, + expectedStatus: http.StatusOK, + expectedArray: false, + expectedPeer: expectedPeer1, + }, { name: "PutPeer", requestType: http.MethodPut, @@ -136,7 +175,7 @@ func TestGetPeers(t *testing.T) { rr := httptest.NewRecorder() - p := initTestMetaData(peer) + p := initTestMetaData(peer, peer1) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -171,6 +210,10 @@ func TestGetPeers(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } + // hardcode this check for now as we only have two peers in this suite + assert.Equal(t, len(respBody), 2) + assert.Equal(t, respBody[1].Connected, false) + got = respBody[0] } else { got = &api.Peer{} @@ -180,12 +223,15 @@ func TestGetPeers(t *testing.T) { } } + fmt.Println(got) + assert.Equal(t, got.Name, tc.expectedPeer.Name) assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion) assert.Equal(t, got.Ip, tc.expectedPeer.IP.String()) assert.Equal(t, got.Os, "OS core") assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled) assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled) + assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected) }) } } diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index d215e1510..0441e8cc0 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -54,6 +54,12 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } + existingUser, ok := account.Users[userID] + if !ok { + util.WriteError(status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) + return + } + req := &api.PutApiUsersUserIdJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -73,10 +79,12 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } newUser, err := h.accountManager.SaveUser(account.Id, user.Id, &server.User{ - Id: userID, - Role: userRole, - AutoGroups: req.AutoGroups, - Blocked: req.IsBlocked, + Id: userID, + Role: userRole, + AutoGroups: req.AutoGroups, + Blocked: req.IsBlocked, + Issued: existingUser.Issued, + IntegrationReference: existingUser.IntegrationReference, }) if err != nil { @@ -153,6 +161,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { Role: req.Role, AutoGroups: req.AutoGroups, IsServiceUser: req.IsServiceUser, + Issued: server.UserIssuedAPI, }) if err != nil { util.WriteError(err, w) @@ -198,9 +207,7 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w) return } - log.Debugf("User %v is service user: %v", r.Name, r.IsServiceUser) if includeServiceUser == r.IsServiceUser { - log.Debugf("Found service user: %v", r.Name) users = append(users, toUserResponse(r, claims.UserId)) } } @@ -271,5 +278,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { IsServiceUser: &user.IsServiceUser, IsBlocked: user.IsBlocked, LastLogin: &user.LastLogin, + Issued: &user.Issued, } } diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index a56507145..b4d449be3 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -33,18 +33,21 @@ var usersTestAccount = &server.Account{ Role: "admin", IsServiceUser: false, AutoGroups: []string{"group_1"}, + Issued: server.UserIssuedAPI, }, regularUserID: { Id: regularUserID, Role: "user", IsServiceUser: false, AutoGroups: []string{"group_1"}, + Issued: server.UserIssuedAPI, }, serviceUserID: { Id: serviceUserID, Role: "user", IsServiceUser: true, AutoGroups: []string{"group_1"}, + Issued: server.UserIssuedAPI, }, }, } @@ -64,6 +67,7 @@ func initUsersTestData() *UsersHandler { Name: "", Email: "", IsServiceUser: v.IsServiceUser, + Issued: v.Issued, }) } return users, nil @@ -170,6 +174,7 @@ func TestGetUsers(t *testing.T) { assert.Equal(t, v.ID, usersTestAccount.Users[v.ID].Id) assert.Equal(t, v.Role, string(usersTestAccount.Users[v.ID].Role)) assert.Equal(t, v.IsServiceUser, usersTestAccount.Users[v.ID].IsServiceUser) + assert.Equal(t, v.Issued, usersTestAccount.Users[v.ID].Issued) } }) } diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 277627310..4e2c3d0b3 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -18,7 +18,7 @@ type ErrorResponse struct { Code int `json:"code"` } -// WriteJSONObject simply writes object to the HTTP reponse in JSON format +// WriteJSONObject simply writes object to the HTTP response in JSON format func WriteJSONObject(w http.ResponseWriter, obj interface{}) { w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json; charset=UTF-8") diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index d3802d8ad..745136f62 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -337,7 +337,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { return nil, err } - log.Debugf("returned user batch for accountID %s on page %d, %v", accountID, page, batch) + log.Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch)) err = res.Body.Close() if err != nil { diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index ca995b299..4bbf09404 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -251,34 +251,18 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada // GetAccount returns all the users for a given profile. func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { - ctx, err := am.authenticationContext() + users, err := am.getAllUsers() if err != nil { return nil, err } - userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute() - if err != nil { - return nil, err - } - defer resp.Body.Close() - if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountGetAccount() } - if resp.StatusCode != http.StatusOK { - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode) - } - - users := make([]*UserData, 0) - for _, user := range userList.Results { - userData := parseAuthentikUser(user) - userData.AppMetadata.WTAccountID = accountID - - users = append(users, userData) + for index, user := range users { + user.AppMetadata.WTAccountID = accountID + users[index] = user } return users, nil @@ -287,37 +271,59 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { - ctx, err := am.authenticationContext() + users, err := am.getAllUsers() if err != nil { return nil, err } - userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute() - if err != nil { - return nil, err - } - defer resp.Body.Close() + indexedUsers := make(map[string][]*UserData) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountGetAllAccounts() } - if resp.StatusCode != http.StatusOK { - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) - } - - indexedUsers := make(map[string][]*UserData) - for _, user := range userList.Results { - userData := parseAuthentikUser(user) - indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) - } - return indexedUsers, nil } +// getAllUsers returns all users in a Authentik account. +func (am *AuthentikManager) getAllUsers() ([]*UserData, error) { + users := make([]*UserData, 0) + + page := int32(1) + for { + ctx, err := am.authenticationContext() + if err != nil { + return nil, err + } + + userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute() + if err != nil { + return nil, err + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestStatusError() + } + return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) + } + + for _, user := range userList.Results { + users = append(users, parseAuthentikUser(user)) + } + + page = int32(userList.GetPagination().Next) + if userList.GetPagination().Next == 0 { + break + } + + } + + return users, nil +} + // CreateUser creates a new user in authentik Idp and sends an invitation. func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index e4224c26d..706e4d330 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -266,10 +266,7 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { // GetAccount returns all the users for a given profile. func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { - q := url.Values{} - q.Add("$select", profileFields) - - body, err := am.get("users", q) + users, err := am.getAllUsers() if err != nil { return nil, err } @@ -278,18 +275,9 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { am.appMetrics.IDPMetrics().CountGetAccount() } - var profiles struct{ Value []azureProfile } - err = am.helper.Unmarshal(body, &profiles) - if err != nil { - return nil, err - } - - users := make([]*UserData, 0) - for _, profile := range profiles.Value { - userData := profile.userData() - userData.AppMetadata.WTAccountID = accountID - - users = append(users, userData) + for index, user := range users { + user.AppMetadata.WTAccountID = accountID + users[index] = user } return users, nil @@ -298,28 +286,16 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { - q := url.Values{} - q.Add("$select", profileFields) - - body, err := am.get("users", q) - if err != nil { - return nil, err - } - - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountGetAllAccounts() - } - - var profiles struct{ Value []azureProfile } - err = am.helper.Unmarshal(body, &profiles) + users, err := am.getAllUsers() if err != nil { return nil, err } indexedUsers := make(map[string][]*UserData) - for _, profile := range profiles.Value { - userData := profile.userData() - indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...) + + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountGetAllAccounts() } return indexedUsers, nil @@ -373,6 +349,39 @@ func (am *AzureManager) DeleteUser(userID string) error { return nil } +// getAllUsers returns all users in an Azure AD account. +func (am *AzureManager) getAllUsers() ([]*UserData, error) { + users := make([]*UserData, 0) + + q := url.Values{} + q.Add("$select", profileFields) + q.Add("$top", "500") + + for nextLink := "users"; nextLink != ""; { + body, err := am.get(nextLink, q) + if err != nil { + return nil, err + } + + var profiles struct { + Value []azureProfile + NextLink string `json:"@odata.nextLink"` + } + err = am.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + for _, profile := range profiles.Value { + users = append(users, profile.userData()) + } + + nextLink = profiles.NextLink + } + + return users, nil +} + // get perform Get requests. func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { jwtToken, err := am.credentials.Authenticate() @@ -380,7 +389,14 @@ func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { return nil, err } - reqURL := fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode()) + var reqURL string + if strings.HasPrefix(resource, "https") { + // Already an absolute URL for paging + reqURL = resource + } else { + reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode()) + } + req, err := http.NewRequest(http.MethodGet, reqURL, nil) if err != nil { return nil, err diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index ed2de9a42..896fb707b 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -96,7 +96,7 @@ func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) // GetUserDataByID requests user data from Google Workspace via ID. func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { - user, err := gm.usersService.Get(userID).Projection("full").Do() + user, err := gm.usersService.Get(userID).Do() if err != nil { return nil, err } @@ -113,43 +113,69 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App // GetAccount returns all the users for a given profile. func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) { - usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do() - if err != nil { - return nil, err - } - - usersData := make([]*UserData, 0) - for _, user := range usersList.Users { - userData := parseGoogleWorkspaceUser(user) - userData.AppMetadata.WTAccountID = accountID - - usersData = append(usersData, userData) - } - - return usersData, nil -} - -// GetAllAccounts gets all registered accounts with corresponding user data. -// It returns a list of users indexed by accountID. -func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) { - usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do() + users, err := gm.getAllUsers() if err != nil { return nil, err } if gm.appMetrics != nil { - gm.appMetrics.IDPMetrics().CountGetAllAccounts() + gm.appMetrics.IDPMetrics().CountGetAccount() + } + + for index, user := range users { + user.AppMetadata.WTAccountID = accountID + users[index] = user + } + + return users, nil +} + +// GetAllAccounts gets all registered accounts with corresponding user data. +// It returns a list of users indexed by accountID. +func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) { + users, err := gm.getAllUsers() + if err != nil { + return nil, err } indexedUsers := make(map[string][]*UserData) - for _, user := range usersList.Users { - userData := parseGoogleWorkspaceUser(user) - indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...) + + if gm.appMetrics != nil { + gm.appMetrics.IDPMetrics().CountGetAllAccounts() } return indexedUsers, nil } +// getAllUsers returns all users in a Google Workspace account filtered by customer ID. +func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) { + users := make([]*UserData, 0) + pageToken := "" + for { + call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500) + if pageToken != "" { + call.PageToken(pageToken) + } + + resp, err := call.Do() + if err != nil { + return nil, err + } + + for _, user := range resp.Users { + users = append(users, parseGoogleWorkspaceUser(user)) + } + + pageToken = resp.NextPageToken + if pageToken == "" { + break + } + } + + return users, nil +} + // CreateUser creates a new user in Google Workspace and sends an invitation. func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") @@ -158,7 +184,7 @@ func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, erro // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) { - user, err := gm.usersService.Get(email).Projection("full").Do() + user, err := gm.usersService.Get(email).Do() if err != nil { return nil, err } diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index 3e7b9357e..67341a26f 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -9,6 +9,7 @@ import ( "time" "github.com/okta/okta-sdk-golang/v2/okta" + "github.com/okta/okta-sdk-golang/v2/okta/query" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -160,7 +161,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) { // GetAccount returns all the users for a given profile. func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { - users, resp, err := om.client.User.ListUsers(context.Background(), nil) + users, err := om.getAllUsers() if err != nil { return nil, err } @@ -169,39 +170,40 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { om.appMetrics.IDPMetrics().CountGetAccount() } - if resp.StatusCode != http.StatusOK { - if om.appMetrics != nil { - om.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get account, statusCode %d", resp.StatusCode) + for index, user := range users { + user.AppMetadata.WTAccountID = accountID + users[index] = user } - list := make([]*UserData, 0) - for _, user := range users { - userData, err := parseOktaUser(user) - if err != nil { - return nil, err - } - userData.AppMetadata.WTAccountID = accountID - - list = append(list, userData) - } - - return list, nil + return users, nil } // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { - users, resp, err := om.client.User.ListUsers(context.Background(), nil) + users, err := om.getAllUsers() if err != nil { return nil, err } + indexedUsers := make(map[string][]*UserData) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...) + if om.appMetrics != nil { om.appMetrics.IDPMetrics().CountGetAllAccounts() } + return indexedUsers, nil +} + +// getAllUsers returns all users in an Okta account. +func (om *OktaManager) getAllUsers() ([]*UserData, error) { + qp := query.NewQueryParams(query.WithLimit(200)) + userList, resp, err := om.client.User.ListUsers(context.Background(), qp) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { if om.appMetrics != nil { om.appMetrics.IDPMetrics().CountRequestStatusError() @@ -209,17 +211,34 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) } - indexedUsers := make(map[string][]*UserData) - for _, user := range users { + for resp.HasNextPage() { + paginatedUsers := make([]*okta.User, 0) + resp, err = resp.Next(context.Background(), &paginatedUsers) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + if om.appMetrics != nil { + om.appMetrics.IDPMetrics().CountRequestStatusError() + } + return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) + } + + userList = append(userList, paginatedUsers...) + } + + users := make([]*UserData, 0, len(userList)) + for _, user := range userList { userData, err := parseOktaUser(user) if err != nil { return nil, err } - indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + users = append(users, userData) } - return indexedUsers, nil + return users, nil } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index b4a527e46..06fc6669d 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -405,7 +405,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, err := NewFileStore(config.Datadir, nil) + store, err := NewStoreFromJson(config.Datadir, nil) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index fa35cfdef..375e7e634 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -393,6 +393,7 @@ var _ = Describe("Management service", func() { ipChannel := make(chan string, 20) for i := 0; i < initialPeers; i++ { go func() { + defer GinkgoRecover() key, _ := wgtypes.GenerateKey() loginPeerWithValidSetupKey(serverPubKey, key, client) encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{}) @@ -496,7 +497,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, err := server.NewFileStore(config.Datadir, nil) + store, err := server.NewStoreFromJson(config.Datadir, nil) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 3b3db0baa..cf6b2e440 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -48,6 +48,7 @@ type properties map[string]interface{} // DataSource metric data source type DataSource interface { GetAllAccounts() []*server.Account + GetStoreEngine() server.StoreEngine } // ConnManager peer connection manager that holds state for current active connections @@ -295,6 +296,7 @@ func (w *Worker) generateProperties() properties { metricsProperties["max_active_peer_version"] = maxActivePeerVersion metricsProperties["ui_clients"] = uiClient metricsProperties["idp_manager"] = w.idpManager + metricsProperties["store_engine"] = w.dataSource.GetStoreEngine() for protocol, count := range rulesProtocol { metricsProperties["rules_protocol_"+protocol] = count diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index c61613fd2..7717ff409 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -151,6 +151,11 @@ func (mockDatasource) GetAllAccounts() []*server.Account { } } +// GetStoreEngine returns FileStoreEngine +func (mockDatasource) GetStoreEngine() server.StoreEngine { + return server.FileStoreEngine +} + // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties func TestGenerateProperties(t *testing.T) { ds := mockDatasource{} @@ -236,4 +241,8 @@ func TestGenerateProperties(t *testing.T) { if properties["user_peers"] != 2 { t.Errorf("expected 2 user_peers, got %d", properties["user_peers"]) } + + if properties["store_engine"] != server.FileStoreEngine { + t.Errorf("expected JsonFile, got %s", properties["store_engine"]) + } } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 5432b201b..ea4a18f56 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -60,7 +60,7 @@ type MockAccountManager struct { GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) @@ -75,6 +75,7 @@ type MockAccountManager struct { LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error) SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error + GetAllConnectedPeersFunc func() (map[string]struct{}, error) } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -463,9 +464,9 @@ func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (* } // CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface -func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) { +func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) { if am.CreateNameServerGroupFunc != nil { - return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled, userID) + return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled) } return nil, nil } @@ -583,3 +584,11 @@ func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *ser } return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } + +// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface +func (am *MockAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { + if am.GetAllConnectedPeersFunc != nil { + return am.GetAllConnectedPeersFunc() + } + return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 9af5b49ad..8ae71dbae 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -35,7 +35,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) } // CreateNameServerGroup creates and saves a new nameserver group -func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) { +func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -46,14 +46,15 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d } newNSGroup := &nbdns.NameServerGroup{ - ID: xid.New().String(), - Name: name, - Description: description, - NameServers: nameServerList, - Groups: groups, - Enabled: enabled, - Primary: primary, - Domains: domains, + ID: xid.New().String(), + Name: name, + Description: description, + NameServers: nameServerList, + Groups: groups, + Enabled: enabled, + Primary: primary, + Domains: domains, + SearchDomainsEnabled: searchDomainEnabled, } err = validateNameServerGroup(false, newNSGroup, account) @@ -174,7 +175,7 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ } } - err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains) + err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) if err != nil { return err } @@ -197,7 +198,7 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ return nil } -func validateDomainInput(primary bool, domains []string) error { +func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { if !primary && len(domains) == 0 { return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ " it should be primary or have at least one domain") @@ -206,6 +207,12 @@ func validateDomainInput(primary bool, domains []string) error { return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+ " you should set either primary or domain") } + + if primary && searchDomainsEnabled { + return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and search domains is enabled,"+ + " you should not set search domains for primary nameservers") + } + for _, domain := range domains { if err := validateDomain(domain); err != nil { return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 26977116b..6210ae538 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -23,13 +23,14 @@ const ( func TestCreateNameServerGroup(t *testing.T) { type input struct { - name string - description string - enabled bool - groups []string - nameServers []nbdns.NameServer - primary bool - domains []string + name string + description string + enabled bool + groups []string + nameServers []nbdns.NameServer + primary bool + domains []string + searchDomains bool } testCases := []struct { @@ -383,6 +384,7 @@ func TestCreateNameServerGroup(t *testing.T) { testCase.inputArgs.domains, testCase.inputArgs.enabled, userID, + testCase.inputArgs.searchDomains, ) testCase.errFunc(t, err) @@ -749,7 +751,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { func createNSStore(t *testing.T) (Store, error) { dataDir := t.TempDir() - store, err := NewFileStore(dataDir, nil) + store, err := NewStoreFromJson(dataDir, nil) if err != nil { return nil, err } diff --git a/management/server/network.go b/management/server/network.go index 70f218f66..c5b165cae 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -34,14 +34,14 @@ type NetworkMap struct { } type Network struct { - Id string - Net net.IPNet - Dns string + Identifier string `json:"id"` + Net net.IPNet `gorm:"serializer:gob"` + Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. Serial uint64 - mu sync.Mutex `json:"-"` + mu sync.Mutex `json:"-" gorm:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 @@ -56,10 +56,10 @@ func NewNetwork() *Network { intn := r.Intn(len(sub)) return &Network{ - Id: xid.New().String(), - Net: sub[intn].IPNet, - Dns: "", - Serial: 0} + Identifier: xid.New().String(), + Net: sub[intn].IPNet, + Dns: "", + Serial: 0} } // IncSerial increments Serial by 1 reflecting that the network state has been changed @@ -78,10 +78,10 @@ func (n *Network) CurrentSerial() uint64 { func (n *Network) Copy() *Network { return &Network{ - Id: n.Id, - Net: n.Net, - Dns: n.Dns, - Serial: n.Serial, + Identifier: n.Identifier, + Net: n.Net, + Dns: n.Dns, + Serial: n.Serial, } } diff --git a/management/server/peer.go b/management/server/peer.go index e5c6e39d6..33c9430fc 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -72,22 +72,24 @@ type PeerLogin struct { // The Peer is a WireGuard peer identified by a public key type Peer struct { // ID is an internal ID of the peer - ID string + ID string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"` // WireGuard public key - Key string + Key string `gorm:"index"` // A setup key this peer was registered with SetupKey string // IP address of the Peer - IP net.IP + IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"` // Meta is a Peer system meta data - Meta PeerSystemMeta + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud DNSLabel string // Status peer's management connection status - Status *PeerStatus + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` // The user ID that registered the peer UserID string // SSHKey is a public SSH key of the peer @@ -116,6 +118,7 @@ func (p *Peer) Copy() *Peer { } return &Peer{ ID: p.ID, + AccountID: p.AccountID, Key: p.Key, SetupKey: p.SetupKey, IP: p.IP, @@ -728,7 +731,7 @@ func checkAuth(loginUserID string, peer *Peer) error { return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } if peer.UserID != loginUserID { - log.Warnf("user mismatch when loggin in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) + log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) return status.Errorf(status.Unauthenticated, "can't login") } return nil diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 36e96df43..9d5a8bfb9 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -369,8 +369,8 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - if account.Network.Id != network.Id { - t.Errorf("expecting Account Networks ID to be equal, got %s expected %s", network.Id, account.Network.Id) + if account.Network.Identifier != network.Identifier { + t.Errorf("expecting Account Networks ID to be equal, got %s expected %s", network.Identifier, account.Network.Identifier) } } diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index c7deca9de..f46666112 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -26,7 +26,9 @@ const ( // PersonalAccessToken holds all information about a PAT including a hashed version of it for verification type PersonalAccessToken struct { - ID string + ID string `gorm:"primaryKey"` + // User is a reference to Account that this object belongs + UserID string `gorm:"index"` Name string HashedToken string ExpirationDate time.Time diff --git a/management/server/policy.go b/management/server/policy.go index 308a5c3c0..b7b5b331c 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -43,7 +43,7 @@ const ( ) const ( - // PolicyRuleFlowDirect allows trafic from source to destination + // PolicyRuleFlowDirect allows traffic from source to destination PolicyRuleFlowDirect = PolicyRuleDirection("direct") // PolicyRuleFlowBidirect allows traffic to both directions PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") @@ -63,7 +63,10 @@ type PolicyUpdateOperation struct { // PolicyRule is the metadata of the policy type PolicyRule struct { // ID of the policy rule - ID string + ID string `gorm:"primaryKey"` + + // PolicyID is a reference to Policy that this object belongs + PolicyID string `json:"-" gorm:"index"` // Name of the rule visible in the UI Name string @@ -78,10 +81,10 @@ type PolicyRule struct { Action PolicyTrafficActionType // Destinations policy destination groups - Destinations []string + Destinations []string `gorm:"serializer:json"` // Sources policy source groups - Sources []string + Sources []string `gorm:"serializer:json"` // Bidirectional define if the rule is applicable in both directions, sources, and destinations Bidirectional bool @@ -90,7 +93,7 @@ type PolicyRule struct { Protocol PolicyRuleProtocolType // Ports or it ranges list - Ports []string + Ports []string `gorm:"serializer:json"` } // Copy returns a copy of a policy rule @@ -128,8 +131,11 @@ func (pm *PolicyRule) ToRule() *Rule { // Policy of the Rego query type Policy struct { - // ID of the policy - ID string + // ID of the policy' + ID string `gorm:"primaryKey"` + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` // Name of the Policy Name string @@ -141,7 +147,7 @@ type Policy struct { Enabled bool // Rules of the policy - Rules []*PolicyRule + Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id"` } // Copy returns a copy of the policy. @@ -201,7 +207,6 @@ type FirewallRule struct { // This function returns the list of peers and firewall rules that are applicable to a given peer. func (a *Account) getPeerConnectionResources(peerID string) ([]*Peer, []*FirewallRule) { generateResources, getAccumulatedResources := a.connResourcesGenerator() - for _, policy := range a.Policies { if !policy.Enabled { continue diff --git a/management/server/policy_test.go b/management/server/policy_test.go index bf003ffac..971bd27d9 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -111,8 +111,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { peers, firewallRules := account.getPeerConnectionResources(p.ID) - assert.GreaterOrEqual(t, len(peers), 2, "mininum number peers should present") - assert.GreaterOrEqual(t, len(firewallRules), 2, "mininum number of firewall rules should present") + assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") + assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) diff --git a/management/server/route.go b/management/server/route.go index 79c207c9b..6b5aa982d 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -98,7 +98,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix for _, id := range group.Peers { if _, ok := seenPeers[id]; ok { - peer := account.GetPeer(peerID) + peer := account.GetPeer(id) if peer == nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } diff --git a/management/server/route_test.go b/management/server/route_test.go index 32f15843b..efd73d6c2 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/rs/xid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" @@ -48,11 +49,12 @@ func TestCreateRoute(t *testing.T) { } testCases := []struct { - name string - inputArgs input - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedRoute *route.Route + name string + inputArgs input + createInitRoute bool + shouldCreate bool + errFunc require.ErrorAssertionFunc + expectedRoute *route.Route }{ { name: "Happy Path", @@ -164,8 +166,9 @@ func TestCreateRoute(t *testing.T) { enabled: true, groups: []string{routeGroup1}, }, - errFunc: require.Error, - shouldCreate: false, + createInitRoute: true, + errFunc: require.Error, + shouldCreate: false, }, { name: "Bad Peers Group already has this route", @@ -179,8 +182,9 @@ func TestCreateRoute(t *testing.T) { enabled: true, groups: []string{routeGroup1}, }, - errFunc: require.Error, - shouldCreate: false, + createInitRoute: true, + errFunc: require.Error, + shouldCreate: false, }, { name: "Empty Peer Should Create", @@ -326,6 +330,18 @@ func TestCreateRoute(t *testing.T) { t.Errorf("failed to init testing account: %s", err) } + if testCase.createInitRoute { + groupAll, errInit := account.GetGroupAll() + if errInit != nil { + t.Errorf("failed to get group all: %s", errInit) + } + _, errInit = am.CreateRoute(account.Id, existingNetwork, "", []string{routeGroup3, routeGroup4}, + "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID) + if errInit != nil { + t.Errorf("failed to create init route: %s", errInit) + } + } + outRoute, err := am.CreateRoute( account.Id, testCase.inputArgs.network, @@ -370,17 +386,18 @@ func TestSaveRoute(t *testing.T) { validGroupHA2 := routeGroupHA2 testCases := []struct { - name string - existingRoute *route.Route - newPeer *string - newPeerGroups []string - newMetric *int - newPrefix *netip.Prefix - newGroups []string - skipCopying bool - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedRoute *route.Route + name string + existingRoute *route.Route + createInitRoute bool + newPeer *string + newPeerGroups []string + newMetric *int + newPrefix *netip.Prefix + newGroups []string + skipCopying bool + shouldCreate bool + errFunc require.ErrorAssertionFunc + expectedRoute *route.Route }{ { name: "Happy Path", @@ -645,8 +662,9 @@ func TestSaveRoute(t *testing.T) { Enabled: true, Groups: []string{routeGroup1}, }, - newPeer: &validUsedPeer, - errFunc: require.Error, + createInitRoute: true, + newPeer: &validUsedPeer, + errFunc: require.Error, }, { name: "Do not allow to modify existing route with a peers group from another route", @@ -662,8 +680,9 @@ func TestSaveRoute(t *testing.T) { Enabled: true, Groups: []string{routeGroup1}, }, - newPeerGroups: []string{routeGroup4}, - errFunc: require.Error, + createInitRoute: true, + newPeerGroups: []string{routeGroup4}, + errFunc: require.Error, }, } for _, testCase := range testCases { @@ -678,6 +697,21 @@ func TestSaveRoute(t *testing.T) { t.Error("failed to init testing account") } + if testCase.createInitRoute { + account.Routes["initRoute"] = &route.Route{ + ID: "initRoute", + Network: netip.MustParsePrefix(existingNetwork), + NetID: existingRouteID, + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroup4}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + } + } + account.Routes[testCase.existingRoute.ID] = testCase.existingRoute err = am.Store.SaveAccount(account) @@ -811,15 +845,15 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { peer1Routes, err := am.GetNetworkMap(peer1ID) require.NoError(t, err) - require.Len(t, peer1Routes.Routes, 3, "HA route should have more than 1 routes") + assert.Len(t, peer1Routes.Routes, 1, "HA route should have 1 server route") peer2Routes, err := am.GetNetworkMap(peer2ID) require.NoError(t, err) - require.Len(t, peer2Routes.Routes, 3, "HA route should have more than 1 routes") + assert.Len(t, peer2Routes.Routes, 1, "HA route should have 1 server route") peer4Routes, err := am.GetNetworkMap(peer4ID) require.NoError(t, err) - require.Len(t, peer4Routes.Routes, 3, "HA route should have more than 1 routes") + assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") groups, err := am.ListGroups(account.Id) require.NoError(t, err) @@ -838,32 +872,32 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { peer2RoutesAfterDelete, err := am.GetNetworkMap(peer2ID) require.NoError(t, err) - require.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have only 2 route") + assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") err = am.GroupDeletePeer(account.Id, groupHA2.ID, peer4ID) require.NoError(t, err) peer2RoutesAfterDelete, err = am.GetNetworkMap(peer2ID) require.NoError(t, err) - require.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") + assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") err = am.GroupAddPeer(account.Id, groupHA2.ID, peer4ID) require.NoError(t, err) peer1RoutesAfterAdd, err := am.GetNetworkMap(peer1ID) require.NoError(t, err) - require.Len(t, peer1RoutesAfterAdd.Routes, 2, "HA route should have more than 1 route") + assert.Len(t, peer1RoutesAfterAdd.Routes, 1, "HA route should have more than 1 route") peer2RoutesAfterAdd, err := am.GetNetworkMap(peer2ID) require.NoError(t, err) - require.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have more than 1 route") + assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") err = am.DeleteRoute(account.Id, newRoute.ID, userID) require.NoError(t, err) peer1DeletedRoute, err := am.GetNetworkMap(peer1ID) require.NoError(t, err) - require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") + assert.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") } func TestGetNetworkMap_RouteSync(t *testing.T) { @@ -983,7 +1017,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterStore(t *testing.T) (Store, error) { dataDir := t.TempDir() - store, err := NewFileStore(dataDir, nil) + store, err := NewStoreFromJson(dataDir, nil) if err != nil { return nil, err } @@ -1193,11 +1227,5 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } } - _, err = am.CreateRoute(account.Id, existingNetwork, "", []string{routeGroup3, routeGroup4}, - "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID) - if err != nil { - return nil, err - } - return am.Store.GetAccount(account.Id) } diff --git a/management/server/rule.go b/management/server/rule.go index cb85d633d..19085840c 100644 --- a/management/server/rule.go +++ b/management/server/rule.go @@ -25,6 +25,9 @@ type Rule struct { // ID of the rule ID string + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + // Name of the rule visible in the UI Name string @@ -35,10 +38,10 @@ type Rule struct { Disabled bool // Source list of groups IDs of peers - Source []string + Source []string `gorm:"serializer:json"` // Destination list of groups IDs of peers - Destination []string + Destination []string `gorm:"serializer:json"` // Flow of the traffic allowed by the rule Flow TrafficFlowType diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 6e626d084..a33f537a7 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -68,13 +68,15 @@ type SetupKeyType string // SetupKey represents a pre-authorized key used to register machines (peers) type SetupKey struct { - Id string + Id string + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` Key string Name string Type SetupKeyType CreatedAt time.Time ExpiresAt time.Time - UpdatedAt time.Time + UpdatedAt time.Time `gorm:"autoUpdateTime:false"` // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) Revoked bool // UsedTimes indicates how many times the key was used @@ -82,7 +84,7 @@ type SetupKey struct { // LastUsed last time the key was used for peer registration LastUsed time.Time // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register - AutoGroups []string + AutoGroups []string `gorm:"serializer:json"` // UsageLimit indicates the number of times this key can be used to enroll a machine. // The value of 0 indicates the unlimited usage. UsageLimit int @@ -99,6 +101,7 @@ func (key *SetupKey) Copy() *SetupKey { } return &SetupKey{ Id: key.Id, + AccountID: key.AccountID, Key: key.Key, Name: key.Name, Type: key.Type, diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go new file mode 100644 index 000000000..ed473e143 --- /dev/null +++ b/management/server/sqlite_store.go @@ -0,0 +1,458 @@ +package server + +import ( + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" + log "github.com/sirupsen/logrus" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" +) + +// SqliteStore represents an account storage backed by a Sqlite DB persisted to disk +type SqliteStore struct { + db *gorm.DB + storeFile string + accountLocks sync.Map + globalAccountLock sync.Mutex + metrics telemetry.AppMetrics + installationPK int +} + +type installation struct { + ID uint `gorm:"primaryKey"` + InstallationIDValue string +} + +// NewSqliteStore restores a store from the file located in the datadir +func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) { + storeStr := "store.db?cache=shared" + if runtime.GOOS == "windows" { + // Vo avoid `The process cannot access the file because it is being used by another process` on Windows + storeStr = "store.db" + } + + file := filepath.Join(dataDir, storeStr) + db, err := gorm.Open(sqlite.Open(file), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + PrepareStmt: true, + }) + if err != nil { + return nil, err + } + + sql, err := db.DB() + if err != nil { + return nil, err + } + conns := runtime.NumCPU() + sql.SetMaxOpenConns(conns) // TODO: make it configurable + + err = db.AutoMigrate( + &SetupKey{}, &Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{}, + &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, + &installation{}, + ) + if err != nil { + return nil, err + } + + return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil +} + +// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir +func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) { + store, err := NewSqliteStore(dataDir, metrics) + if err != nil { + return nil, err + } + + err = store.SaveInstallationID(filestore.InstallationID) + if err != nil { + return nil, err + } + + for _, account := range filestore.GetAllAccounts() { + err := store.SaveAccount(account) + if err != nil { + return nil, err + } + } + + return store, nil +} + +// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock +func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { + log.Debugf("acquiring global lock") + start := time.Now() + s.globalAccountLock.Lock() + + unlock = func() { + s.globalAccountLock.Unlock() + log.Debugf("released global lock in %v", time.Since(start)) + } + + took := time.Since(start) + log.Debugf("took %v to acquire global lock", took) + if s.metrics != nil { + s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) + } + + return unlock +} + +func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { + log.Debugf("acquiring lock for account %s", accountID) + + start := time.Now() + value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) + mtx := value.(*sync.Mutex) + mtx.Lock() + + unlock = func() { + mtx.Unlock() + log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + } + + return unlock +} + +func (s *SqliteStore) SaveAccount(account *Account) error { + start := time.Now() + + for _, key := range account.SetupKeys { + account.SetupKeysG = append(account.SetupKeysG, *key) + } + + for id, peer := range account.Peers { + peer.ID = id + account.PeersG = append(account.PeersG, *peer) + } + + for id, user := range account.Users { + user.Id = id + for id, pat := range user.PATs { + pat.ID = id + user.PATsG = append(user.PATsG, *pat) + } + account.UsersG = append(account.UsersG, *user) + } + + for id, group := range account.Groups { + group.ID = id + account.GroupsG = append(account.GroupsG, *group) + } + + for id, rule := range account.Rules { + rule.ID = id + account.RulesG = append(account.RulesG, *rule) + } + + for id, route := range account.Routes { + route.ID = id + account.RoutesG = append(account.RoutesG, *route) + } + + for id, ns := range account.NameServerGroups { + ns.ID = id + account.NameServerGroupsG = append(account.NameServerGroupsG, *ns) + } + + err := s.db.Transaction(func(tx *gorm.DB) error { + result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account) + if result.Error != nil { + return result.Error + } + + result = tx. + Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.OnConflict{UpdateAll: true}).Create(account) + if result.Error != nil { + return result.Error + } + return nil + }) + + took := time.Since(start) + if s.metrics != nil { + s.metrics.StoreMetrics().CountPersistenceDuration(took) + } + log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds()) + + return err +} + +func (s *SqliteStore) SaveInstallationID(ID string) error { + installation := installation{InstallationIDValue: ID} + installation.ID = uint(s.installationPK) + + return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error +} + +func (s *SqliteStore) GetInstallationID() string { + var installation installation + + if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil { + return "" + } + + return installation.InstallationIDValue +} + +func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus PeerStatus) error { + var peer Peer + + result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID) + if result.Error != nil { + return status.Errorf(status.NotFound, "peer %s not found", peerID) + } + + peer.Status = &peerStatus + + return s.db.Save(peer).Error +} + +// DeleteHashedPAT2TokenIDIndex is noop in Sqlite +func (s *SqliteStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { + return nil +} + +// DeleteTokenID2UserIDIndex is noop in Sqlite +func (s *SqliteStore) DeleteTokenID2UserIDIndex(tokenID string) error { + return nil +} + +func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error) { + var account Account + + result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", + strings.ToLower(domain), true, PrivateCategory) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") + } + + // TODO: rework to not call GetAccount + return s.GetAccount(account.Id) +} + +func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) { + var key SetupKey + result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if key.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(key.AccountID) +} + +func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error) { + var token PersonalAccessToken + result := s.db.First(&token, "hashed_token = ?", hashedToken) + if result.Error != nil { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return token.ID, nil +} + +func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) { + var token PersonalAccessToken + result := s.db.First(&token, "id = ?", tokenID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if token.UserID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + var user User + result = s.db.Preload("PATsG").First(&user, "id = ?", token.UserID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = &pat + } + + return &user, nil +} + +func (s *SqliteStore) GetAllAccounts() (all []*Account) { + var accounts []Account + result := s.db.Find(&accounts) + if result.Error != nil { + return all + } + + for _, account := range accounts { + if acc, err := s.GetAccount(account.Id); err == nil { + all = append(all, acc) + } + } + + return all +} + +func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { + var account Account + + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specifies as this is nester reference + Preload(clause.Associations). + First(&account, "id = ?", accountID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found") + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*PolicyRule + err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "account not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + account.Rules = make(map[string]*Rule, len(account.RulesG)) + for _, rule := range account.RulesG { + account.Rules[rule.ID] = rule.Copy() + } + account.RulesG = nil + + account.Routes = make(map[string]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) { + var user User + result := s.db.Select("account_id").First(&user, "id = ?", userID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if user.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(user.AccountID) +} + +func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) { + var peer Peer + result := s.db.Select("account_id").First(&peer, "id = ?", peerID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if peer.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(peer.AccountID) +} + +func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { + var peer Peer + + result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if peer.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(peer.AccountID) +} + +// SaveUserLastLogin stores the last login time for a user in DB. +func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { + var user User + + result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID) + if result.Error != nil { + return status.Errorf(status.NotFound, "user %s not found", userID) + } + + user.LastLogin = lastLogin + + return s.db.Save(user).Error +} + +// Close is noop in Sqlite +func (s *SqliteStore) Close() error { + return nil +} + +// GetStoreEngine returns SqliteStoreEngine +func (s *SqliteStore) GetStoreEngine() StoreEngine { + return SqliteStoreEngine +} diff --git a/management/server/sqlite_store_test.go b/management/server/sqlite_store_test.go new file mode 100644 index 000000000..4a16e2525 --- /dev/null +++ b/management/server/sqlite_store_test.go @@ -0,0 +1,229 @@ +package server + +import ( + "fmt" + "net" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/google/uuid" + "github.com/netbirdio/netbird/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSqlite_NewStore(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStore(t) + + if len(store.GetAllAccounts()) != 0 { + t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") + } +} + +func TestSqlite_SaveAccount(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStore(t) + + account := newAccountWithId("account_id", "testuser", "") + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + account.Peers["testpeer"] = &Peer{ + Key: "peerkey", + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: PeerSystemMeta{}, + Name: "peer name", + Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } + + err := store.SaveAccount(account) + require.NoError(t, err) + + account2 := newAccountWithId("account_id2", "testuser2", "") + setupKey = GenerateDefaultSetupKey() + account2.SetupKeys[setupKey.Key] = setupKey + account2.Peers["testpeer2"] = &Peer{ + Key: "peerkey2", + SetupKey: "peerkeysetupkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: PeerSystemMeta{}, + Name: "peer name 2", + Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } + + err = store.SaveAccount(account2) + require.NoError(t, err) + + if len(store.GetAllAccounts()) != 2 { + t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") + } + + a, err := store.GetAccount(account.Id) + if a == nil { + t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) + } + + if a != nil && len(a.Policies) != 1 { + t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies)) + } + + if a != nil && len(a.Policies[0].Rules) != 1 { + t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules)) + return + } + + if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil { + t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) + } + + if a, err := store.GetAccountByUser("testuser"); a == nil { + t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) + } + + if a, err := store.GetAccountByPeerID("testpeer"); a == nil { + t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) + } + + if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil { + t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) + } +} + +func TestSqlite_SavePeerStatus(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/store.json") + + account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + require.NoError(t, err) + + // save status of non-existing peer + newStatus := PeerStatus{Connected: true, LastSeen: time.Now().UTC()} + err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + assert.Error(t, err) + + // save new status of existing peer + account.Peers["testpeer"] = &Peer{ + Key: "peerkey", + ID: "testpeer", + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: PeerSystemMeta{}, + Name: "peer name", + Status: &PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, + } + + err = store.SaveAccount(account) + require.NoError(t, err) + + err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + require.NoError(t, err) + + account, err = store.GetAccount(account.Id) + require.NoError(t, err) + + actual := account.Peers["testpeer"].Status + assert.Equal(t, newStatus, *actual) +} + +func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/store.json") + + existingDomain := "test.com" + + account, err := store.GetAccountByPrivateDomain(existingDomain) + require.NoError(t, err, "should found account") + require.Equal(t, existingDomain, account.Domain, "domains should match") + + _, err = store.GetAccountByPrivateDomain("missing-domain.com") + require.Error(t, err, "should return error on domain lookup") +} + +func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/store.json") + + hashed := "SoMeHaShEdToKeN" + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" + + token, err := store.GetTokenIDByHashedToken(hashed) + require.NoError(t, err) + require.Equal(t, id, token) +} + +func TestSqlite_GetUserByTokenID(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/store.json") + + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" + + user, err := store.GetUserByTokenID(id) + require.NoError(t, err) + require.Equal(t, id, user.PATs[id].ID) +} + +func newSqliteStore(t *testing.T) *SqliteStore { + t.Helper() + + store, err := NewSqliteStore(t.TempDir(), nil) + require.NoError(t, err) + require.NotNil(t, store) + + return store +} + +func newSqliteStoreFromFile(t *testing.T, filename string) *SqliteStore { + t.Helper() + + storeDir := t.TempDir() + + err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) + require.NoError(t, err) + + fStore, err := NewFileStore(storeDir, nil) + require.NoError(t, err) + + store, err := NewSqliteStoreFromFileStore(fStore, storeDir, nil) + require.NoError(t, err) + require.NotNil(t, store) + + return store +} + +func newAccount(store Store, id int) error { + str := fmt.Sprintf("%s-%d", uuid.New().String(), id) + account := newAccountWithId(str, str+"-testuser", "example.com") + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + account.Peers["p"+str] = &Peer{ + Key: "peerkey" + str, + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: PeerSystemMeta{}, + Name: "peer name", + Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } + + return store.SaveAccount(account) +} diff --git a/management/server/store.go b/management/server/store.go index 9ebe41235..66b239f96 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -1,6 +1,15 @@ package server -import "time" +import ( + "fmt" + "os" + "strings" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/telemetry" +) type Store interface { GetAllAccounts() []*Account @@ -25,4 +34,65 @@ type Store interface { SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error // Close should close the store persisting all unsaved data. Close() error + // GetStoreEngine should return StoreEngine of the current store implementation. + // This is also a method of metrics.DataSource interface. + GetStoreEngine() StoreEngine +} + +type StoreEngine string + +const ( + FileStoreEngine StoreEngine = "jsonfile" + SqliteStoreEngine StoreEngine = "sqlite" +) + +func getStoreEngineFromEnv() StoreEngine { + // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise rely on the config file. + kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") + if !ok { + return FileStoreEngine + } + + value := StoreEngine(strings.ToLower(kind)) + + if value == FileStoreEngine || value == SqliteStoreEngine { + return value + } + + return FileStoreEngine +} + +func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { + if kind == "" { + // fallback to env. Normally this only should be used from tests + kind = getStoreEngineFromEnv() + } + switch kind { + case FileStoreEngine: + log.Info("using JSON file store engine") + return NewFileStore(dataDir, metrics) + case SqliteStoreEngine: + log.Info("using SQLite store engine") + return NewSqliteStore(dataDir, metrics) + default: + return nil, fmt.Errorf("unsupported kind of store %s", kind) + } +} + +func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, error) { + fstore, err := NewFileStore(dataDir, nil) + if err != nil { + return nil, err + } + + kind := getStoreEngineFromEnv() + + switch kind { + case FileStoreEngine: + return fstore, nil + case SqliteStoreEngine: + return NewSqliteStoreFromFileStore(fstore, dataDir, metrics) + default: + return nil, fmt.Errorf("unsupported store engine %s", kind) + } } diff --git a/management/server/store_test.go b/management/server/store_test.go new file mode 100644 index 000000000..72bbaf949 --- /dev/null +++ b/management/server/store_test.go @@ -0,0 +1,88 @@ +package server + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +type benchCase struct { + name string + storeFn func(b *testing.B) Store + size int +} + +var newFs = func(b *testing.B) Store { + store, _ := NewFileStore(b.TempDir(), nil) + return store +} + +var newSqlite = func(b *testing.B) Store { + store, _ := NewSqliteStore(b.TempDir(), nil) + return store +} + +func BenchmarkTest_StoreWrite(b *testing.B) { + cases := []benchCase{ + {name: "FileStore_Write", storeFn: newFs, size: 100}, + {name: "SqliteStore_Write", storeFn: newSqlite, size: 100}, + {name: "FileStore_Write", storeFn: newFs, size: 500}, + {name: "SqliteStore_Write", storeFn: newSqlite, size: 500}, + {name: "FileStore_Write", storeFn: newFs, size: 1000}, + {name: "SqliteStore_Write", storeFn: newSqlite, size: 1000}, + {name: "FileStore_Write", storeFn: newFs, size: 2000}, + {name: "SqliteStore_Write", storeFn: newSqlite, size: 2000}, + } + + for _, c := range cases { + name := fmt.Sprintf("%s_%d", c.name, c.size) + store := c.storeFn(b) + + for i := 0; i < c.size; i++ { + _ = newAccount(store, i) + } + + b.Run(name, func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + i := c.size + for pb.Next() { + i++ + err := newAccount(store, i) + require.NoError(b, err) + } + }) + }) + } +} + +func BenchmarkTest_StoreRead(b *testing.B) { + cases := []benchCase{ + {name: "FileStore_Read", storeFn: newFs, size: 100}, + {name: "SqliteStore_Read", storeFn: newSqlite, size: 100}, + {name: "FileStore_Read", storeFn: newFs, size: 500}, + {name: "SqliteStore_Read", storeFn: newSqlite, size: 500}, + {name: "FileStore_Read", storeFn: newFs, size: 1000}, + {name: "SqliteStore_Read", storeFn: newSqlite, size: 1000}, + } + + for _, c := range cases { + name := fmt.Sprintf("%s_%d", c.name, c.size) + store := c.storeFn(b) + + for i := 0; i < c.size; i++ { + _ = newAccount(store, i) + } + + accounts := store.GetAllAccounts() + id := accounts[c.size-1].Id + + b.Run(name, func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, _ = store.GetAccount(id) + } + }) + }) + } +} diff --git a/management/server/testdata/store.json b/management/server/testdata/store.json index ecde766c3..1fa4e3a9a 100644 --- a/management/server/testdata/store.json +++ b/management/server/testdata/store.json @@ -2,52 +2,87 @@ "Accounts": { "bf1c8084-ba50-4ce7-9439-34653001fc3b": { "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", + "CreatedBy": "", "Domain": "test.com", "DomainCategory": "private", "IsDomainPrimaryAccount": true, "SetupKeys": { "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { + "Id": "", + "AccountID": "", "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", "Name": "Default key", "Type": "reusable", "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", + "UpdatedAt": "0001-01-01T00:00:00Z", "Revoked": false, - "UsedTimes": 0 - + "UsedTimes": 0, + "LastUsed": "0001-01-01T00:00:00Z", + "AutoGroups": null, + "UsageLimit": 0, + "Ephemeral": false } }, "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", + "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", "Net": { "IP": "100.64.0.0", "Mask": "//8AAA==" }, - "Dns": null + "Dns": "", + "Serial": 0 }, "Peers": {}, "Users": { "edafee4e-63fb-11ec-90d6-0242ac120003": { "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", + "AccountID": "", "Role": "admin", - "PATs": {} + "IsServiceUser": false, + "ServiceUserName": "", + "AutoGroups": null, + "PATs": {}, + "Blocked": false, + "LastLogin": "0001-01-01T00:00:00Z" }, "f4f6d672-63fb-11ec-90d6-0242ac120003": { "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", + "AccountID": "", "Role": "user", + "IsServiceUser": false, + "ServiceUserName": "", + "AutoGroups": null, "PATs": { "9dj38s35-63fb-11ec-90d6-0242ac120003": { - "ID":"9dj38s35-63fb-11ec-90d6-0242ac120003", - "Description":"some Description", - "HashedToken":"SoMeHaShEdToKeN", - "ExpirationDate":"2023-02-27T00:00:00Z", - "CreatedBy":"user", - "CreatedAt":"2023-01-01T00:00:00Z", - "LastUsed":"2023-02-01T00:00:00Z" + "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", + "UserID": "", + "Name": "", + "HashedToken": "SoMeHaShEdToKeN", + "ExpirationDate": "2023-02-27T00:00:00Z", + "CreatedBy": "user", + "CreatedAt": "2023-01-01T00:00:00Z", + "LastUsed": "2023-02-01T00:00:00Z" } - } + }, + "Blocked": false, + "LastLogin": "0001-01-01T00:00:00Z" } + }, + "Groups": null, + "Rules": null, + "Policies": [], + "Routes": null, + "NameServerGroups": null, + "DNSSettings": null, + "Settings": { + "PeerLoginExpirationEnabled": false, + "PeerLoginExpiration": 86400000000000, + "GroupsPropagationEnabled": false, + "JWTGroupsEnabled": false, + "JWTGroupsClaimName": "" } } - } + }, + "InstallationID": "" } \ No newline at end of file diff --git a/management/server/user.go b/management/server/user.go index 3169c784f..22edd2c2c 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -22,6 +22,9 @@ const ( UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" UserStatusInvited UserStatus = "invited" + + UserIssuedAPI = "api" + UserIssuedIntegration = "integration" ) // StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown @@ -42,20 +45,38 @@ type UserStatus string // UserRole is the role of a User type UserRole string +// IntegrationReference holds the reference to a particular integration +type IntegrationReference struct { + ID int + IntegrationType string +} + +func (ir IntegrationReference) String() string { + return fmt.Sprintf("%d:%s", ir.ID, ir.IntegrationType) +} + // User represents a user of the system type User struct { - Id string + Id string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` Role UserRole IsServiceUser bool // ServiceUserName is only set if IsServiceUser is true ServiceUserName string // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user - AutoGroups []string - PATs map[string]*PersonalAccessToken + AutoGroups []string `gorm:"serializer:json"` + PATs map[string]*PersonalAccessToken `gorm:"-"` + PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` // Blocked indicates whether the user is blocked. Blocked users can't use the system. Blocked bool // LastLogin is the last time the user logged in to IdP LastLogin time.Time + + // Issued of the user + Issued string `gorm:"default:api"` + + IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } // IsBlocked returns true if the user is blocked, false otherwise @@ -90,6 +111,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { IsServiceUser: u.IsServiceUser, IsBlocked: u.Blocked, LastLogin: u.LastLogin, + Issued: u.Issued, }, nil } if userData.ID != u.Id { @@ -111,6 +133,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { IsServiceUser: u.IsServiceUser, IsBlocked: u.Blocked, LastLogin: u.LastLogin, + Issued: u.Issued, }, nil } @@ -123,36 +146,40 @@ func (u *User) Copy() *User { pats[k] = v.Copy() } return &User{ - Id: u.Id, - Role: u.Role, - AutoGroups: autoGroups, - IsServiceUser: u.IsServiceUser, - ServiceUserName: u.ServiceUserName, - PATs: pats, - Blocked: u.Blocked, - LastLogin: u.LastLogin, + Id: u.Id, + AccountID: u.AccountID, + Role: u.Role, + AutoGroups: autoGroups, + IsServiceUser: u.IsServiceUser, + ServiceUserName: u.ServiceUserName, + PATs: pats, + Blocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + IntegrationReference: u.IntegrationReference, } } // NewUser creates a new user -func NewUser(id string, role UserRole, isServiceUser bool, serviceUserName string, autoGroups []string) *User { +func NewUser(id string, role UserRole, isServiceUser bool, serviceUserName string, autoGroups []string, issued string) *User { return &User{ Id: id, Role: role, IsServiceUser: isServiceUser, ServiceUserName: serviceUserName, AutoGroups: autoGroups, + Issued: issued, } } // NewRegularUser creates a new user with role UserRoleUser func NewRegularUser(id string) *User { - return NewUser(id, UserRoleUser, false, "", []string{}) + return NewUser(id, UserRoleUser, false, "", []string{}, UserIssuedAPI) } // NewAdminUser creates a new user with role UserRoleAdmin func NewAdminUser(id string) *User { - return NewUser(id, UserRoleAdmin, false, "", []string{}) + return NewUser(id, UserRoleAdmin, false, "", []string{}, UserIssuedAPI) } // createServiceUser creates a new service user under the given account. @@ -174,7 +201,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs } newUserID := uuid.New().String() - newUser := NewUser(newUserID, role, true, serviceUserName, autoGroups) + newUser := NewUser(newUserID, role, true, serviceUserName, autoGroups, UserIssuedAPI) log.Debugf("New User: %v", newUser) account.Users[newUserID] = newUser @@ -195,6 +222,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs Status: string(UserStatusActive), IsServiceUser: true, LastLogin: time.Time{}, + Issued: UserIssuedAPI, }, nil } @@ -224,10 +252,20 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } - // initiator is the one who is inviting the new user - initiatorUser, err := am.lookupUserInCache(userID, account) + initiatorUser, err := account.FindUser(userID) if err != nil { - return nil, status.Errorf(status.NotFound, "user %s doesn't exist in IdP", userID) + return nil, status.Errorf(status.NotFound, "initiator user with ID %s doesn't exist", userID) + } + + inviterID := userID + if initiatorUser.IsServiceUser { + inviterID = account.CreatedBy + } + + // inviterUser is the one who is inviting the new user + inviterUser, err := am.lookupUserInCache(inviterID, account) + if err != nil || inviterUser == nil { + return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID) } // check if the user is already registered with this email => reject @@ -249,16 +287,18 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") } - idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID, initiatorUser.Email) + idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID, inviterUser.Email) if err != nil { return nil, err } role := StrRoleToUserRole(invite.Role) newUser := &User{ - Id: idpUser.ID, - Role: role, - AutoGroups: invite.AutoGroups, + Id: idpUser.ID, + Role: role, + AutoGroups: invite.AutoGroups, + Issued: invite.Issued, + IntegrationReference: invite.IntegrationReference, } account.Users[idpUser.ID] = newUser @@ -285,6 +325,14 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( return nil, fmt.Errorf("failed to get account with token claims %v", err) } + unlock := am.Store.AcquireAccountLock(account.Id) + defer unlock() + + account, err = am.Store.GetAccount(account.Id) + if err != nil { + return nil, fmt.Errorf("failed to get an account from store %v", err) + } + user, ok := account.Users[claims.UserId] if !ok { return nil, status.Errorf(status.NotFound, "user not found") @@ -292,16 +340,16 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( // this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. - unlock := am.Store.AcquireAccountLock(account.Id) newLogin := user.LastDashboardLoginChanged(claims.LastLogin) + err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin) - unlock() + if err != nil { + log.Errorf("failed saving user last login: %v", err) + } + if newLogin { meta := map[string]any{"timestamp": claims.LastLogin} am.storeEvent(claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) - if err != nil { - log.Errorf("failed saving user last login: %v", err) - } } return user, nil @@ -339,6 +387,10 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t return status.Errorf(status.NotFound, "target user not found") } + if targetUser.Issued == UserIssuedIntegration { + return status.Errorf(status.PermissionDenied, "only integration can delete this user") + } + // handle service user first and exit, no need to fetch extra data from IDP, etc if targetUser.IsServiceUser { am.deleteServiceUser(account, initiatorUserID, targetUser) @@ -667,7 +719,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) - // need force update all auto groups in any case they will not be dublicated + // need force update all auto groups in any case they will not be duplicated account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) diff --git a/management/server/user_test.go b/management/server/user_test.go index 1565814b8..f1b997186 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -251,6 +251,7 @@ func TestUser_Copy(t *testing.T) { // this is an imaginary case which will never be in DB this way user := User{ Id: "userId", + AccountID: "accountId", Role: "role", IsServiceUser: true, ServiceUserName: "servicename", @@ -268,6 +269,11 @@ func TestUser_Copy(t *testing.T) { }, Blocked: false, LastLogin: time.Now(), + Issued: "test", + IntegrationReference: IntegrationReference{ + ID: 0, + IntegrationType: "test", + }, } err := validateStruct(user) @@ -291,6 +297,11 @@ func validateStruct(s interface{}) (err error) { field := structVal.Field(i) fieldName := structType.Field(i).Name + // skip gorm internal fields + if json, ok := structType.Field(i).Tag.Lookup("json"); ok && json == "-" { + continue + } + isSet := field.IsValid() && (!field.IsZero() || field.Type().String() == "bool") if !isSet { @@ -447,12 +458,25 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { func TestUser_DeleteUser_regularUser(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") + targetId := "user2" account.Users[targetId] = &User{ Id: targetId, IsServiceUser: true, ServiceUserName: "user2username", } + targetId = "user3" + account.Users[targetId] = &User{ + Id: targetId, + IsServiceUser: false, + Issued: UserIssuedAPI, + } + targetId = "user4" + account.Users[targetId] = &User{ + Id: targetId, + IsServiceUser: false, + Issued: UserIssuedIntegration, + } err := store.SaveAccount(account) if err != nil { @@ -464,10 +488,37 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - err = am.DeleteUser(mockAccountID, mockUserID, targetId) - if err != nil { - t.Errorf("unexpected error: %s", err) + testCases := []struct { + name string + userID string + assertErrFunc assert.ErrorAssertionFunc + assertErrMessage string + }{ + { + name: "Delete service user successfully ", + userID: "user2", + assertErrFunc: assert.NoError, + }, + { + name: "Delete regular user successfully ", + userID: "user3", + assertErrFunc: assert.NoError, + }, + { + name: "Delete integration regular user permission denied ", + userID: "user4", + assertErrFunc: assert.Error, + assertErrMessage: "only integration can delete this user", + }, } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + err = am.DeleteUser(mockAccountID, mockUserID, testCase.userID) + testCase.assertErrFunc(t, err, testCase.assertErrMessage) + }) + } + } func TestDefaultAccountManager_GetUser(t *testing.T) { diff --git a/release_files/darwin-ui-installer.sh b/release_files/darwin-ui-installer.sh index 7e8115b64..5179f02d6 100644 --- a/release_files/darwin-ui-installer.sh +++ b/release_files/darwin-ui-installer.sh @@ -10,6 +10,7 @@ then wiretrustee service stop || true wiretrustee service uninstall || true fi + # check if netbird is installed NB_BIN=$(which netbird) if [ -z "$NB_BIN" ] @@ -41,4 +42,4 @@ netbird service install 2> /dev/null || true netbird service start || true # start app -open /Applications/Netbird\ UI.app \ No newline at end of file +open /Applications/Netbird\ UI.app diff --git a/release_files/darwin_pkg/preinstall b/release_files/darwin_pkg/preinstall index cdea1465c..5965e82eb 100755 --- a/release_files/darwin_pkg/preinstall +++ b/release_files/darwin_pkg/preinstall @@ -8,6 +8,13 @@ AGENT=/usr/local/bin/netbird mkdir -p /var/log/netbird/ { + # check if it was installed with brew + brew list --formula | grep netbird + if [ $? -eq 0 ] + then + echo "NetBird has been installed with Brew. Please use Brew to update the package." + exit 1 + fi osascript -e 'quit app "Netbird"' || true $AGENT service stop || true diff --git a/release_files/install.sh b/release_files/install.sh index c553cc28a..a0a9abf98 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -23,19 +23,28 @@ if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then SUDO="sudo" fi -get_latest_release() { +if [ -z ${NETBIRD_RELEASE+x} ]; then + NETBIRD_RELEASE=latest +fi + +get_release() { + local RELEASE=$1 + if [ "$RELEASE" = "latest" ]; then + local TAG="latest" + else + local TAG="tags/${RELEASE}" + fi if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/latest" \ + curl -H "Authorization: token ${GITHUB_TOKEN}" -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' else - curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/latest" \ + curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' fi - } download_release_binary() { - VERSION=$(get_latest_release) + VERSION=$(get_release "$NETBIRD_RELEASE") BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" @@ -66,9 +75,14 @@ download_release_binary() { if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then INSTALL_DIR="/Applications/NetBird UI.app" + if test -d "$INSTALL_DIR" ; then + echo "removing $INSTALL_DIR" + rm -rfv "$INSTALL_DIR" + fi + # Unzip the app and move to INSTALL_DIR unzip -q -o "$BINARY_NAME" - mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR" + mv "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/" else ${SUDO} mkdir -p "$INSTALL_DIR" tar -xzvf "$BINARY_NAME" @@ -184,16 +198,6 @@ install_netbird() { fi fi - # Checks if SKIP_UI_APP env is set - if [ -z "$SKIP_UI_APP" ]; then - SKIP_UI_APP=false - else - if $SKIP_UI_APP; then - echo "SKIP_UI_APP has been set to true in the environment" - echo "NetBird UI installation will be omitted based on your preference" - fi - fi - # Run the installation, if a desktop environment is not detected # only the CLI will be installed case "$PACKAGE_MANAGER" in @@ -294,14 +298,22 @@ is_bin_package_manager() { fi } +stop_running_netbird_ui() { + NB_UI_PROC=$(ps -ef | grep "[n]etbird-ui" | awk '{print $2}') + if [ -n "$NB_UI_PROC" ]; then + echo "NetBird UI is running with PID $NB_UI_PROC. Stopping it..." + kill -9 "$NB_UI_PROC" + fi +} + update_netbird() { if is_bin_package_manager "$CONFIG_FILE"; then - latest_release=$(get_latest_release) + latest_release=$(get_release "latest") latest_version=${latest_release#v} installed_version=$(netbird version) if [ "$latest_version" = "$installed_version" ]; then - echo "Installed netbird version ($installed_version) is up-to-date" + echo "Installed NetBird version ($installed_version) is up-to-date" exit 0 fi @@ -310,8 +322,9 @@ update_netbird() { echo "" echo "Initiating NetBird update. This will stop the netbird service and restart it after the update" - ${SUDO} netbird service stop - ${SUDO} netbird service uninstall + ${SUDO} netbird service stop || true + ${SUDO} netbird service uninstall || true + stop_running_netbird_ui install_native_binaries ${SUDO} netbird service install @@ -322,6 +335,16 @@ update_netbird() { fi } +# Checks if SKIP_UI_APP env is set +if [ -z "$SKIP_UI_APP" ]; then + SKIP_UI_APP=false +else + if $SKIP_UI_APP; then + echo "SKIP_UI_APP has been set to true in the environment" + echo "NetBird UI installation will be omitted based on your preference" + fi +fi + # Identify OS name and default package manager if type uname >/dev/null 2>&1; then case "$(uname)" in @@ -334,10 +357,10 @@ if type uname >/dev/null 2>&1; then if [ "$ARCH" != "amd64" ] && [ "$ARCH" != "arm64" ] \ && [ "$ARCH" != "x86_64" ];then SKIP_UI_APP=true - echo "NetBird UI installation will be omitted as $ARCH is not a compactible architecture" + echo "NetBird UI installation will be omitted as $ARCH is not a compatible architecture" fi - # Allow netbird UI installation for linux running desktop enviroment + # Allow netbird UI installation for linux running desktop environment if [ -z "$XDG_CURRENT_DESKTOP" ];then SKIP_UI_APP=true echo "NetBird UI installation will be omitted as Linux does not run desktop environment" @@ -376,7 +399,13 @@ if type uname >/dev/null 2>&1; then esac fi -case "$1" in +UPDATE_FLAG=$1 + +if [ "${UPDATE_NETBIRD}-x" = "true-x" ]; then + UPDATE_FLAG="--update" +fi + +case "$UPDATE_FLAG" in --update) update_netbird ;; diff --git a/route/route.go b/route/route.go index eb7bcba2f..194e0c80d 100644 --- a/route/route.go +++ b/route/route.go @@ -65,17 +65,19 @@ func ToPrefixType(prefix string) NetworkType { // Route represents a route type Route struct { - ID string - Network netip.Prefix + ID string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `gorm:"index"` + Network netip.Prefix `gorm:"serializer:gob"` NetID string Description string Peer string - PeerGroups []string + PeerGroups []string `gorm:"serializer:gob"` NetworkType NetworkType Masquerade bool Metric int Enabled bool - Groups []string + Groups []string `gorm:"serializer:json"` } // EventMeta returns activity event meta related to the route diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 08430e8ef..fef443173 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -79,7 +79,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo transportOption, grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 15 * time.Second, + Time: 30 * time.Second, Timeout: 10 * time.Second, })) diff --git a/version/update.go b/version/update.go new file mode 100644 index 000000000..1de60ea9a --- /dev/null +++ b/version/update.go @@ -0,0 +1,184 @@ +package version + +import ( + "io" + "net/http" + "sync" + "time" + + goversion "github.com/hashicorp/go-version" + log "github.com/sirupsen/logrus" +) + +const ( + fetchPeriod = 30 * time.Minute +) + +var ( + versionURL = "https://pkgs.netbird.io/releases/latest/version" +) + +// Update fetch the version info periodically and notify the onUpdateListener in case the UI version or the +// daemon version are deprecated +type Update struct { + uiVersion *goversion.Version + daemonVersion *goversion.Version + latestAvailable *goversion.Version + versionsLock sync.Mutex + + fetchTicker *time.Ticker + fetchDone chan struct{} + + onUpdateListener func() + listenerLock sync.Mutex +} + +// NewUpdate instantiate Update and start to fetch the new version information +func NewUpdate() *Update { + currentVersion, err := goversion.NewVersion(version) + if err != nil { + currentVersion, _ = goversion.NewVersion("0.0.0") + } + + latestAvailable, _ := goversion.NewVersion("0.0.0") + + u := &Update{ + latestAvailable: latestAvailable, + uiVersion: currentVersion, + fetchTicker: time.NewTicker(fetchPeriod), + fetchDone: make(chan struct{}), + } + go u.startFetcher() + return u +} + +// StopWatch stop the version info fetch loop +func (u *Update) StopWatch() { + u.fetchTicker.Stop() + + select { + case u.fetchDone <- struct{}{}: + default: + } +} + +// SetDaemonVersion update the currently running daemon version. If new version is available it will trigger +// the onUpdateListener +func (u *Update) SetDaemonVersion(newVersion string) bool { + daemonVersion, err := goversion.NewVersion(newVersion) + if err != nil { + daemonVersion, _ = goversion.NewVersion("0.0.0") + } + + u.versionsLock.Lock() + if u.daemonVersion != nil && u.daemonVersion.Equal(daemonVersion) { + u.versionsLock.Unlock() + return false + } + + u.daemonVersion = daemonVersion + u.versionsLock.Unlock() + return u.checkUpdate() +} + +// SetOnUpdateListener set new update listener +func (u *Update) SetOnUpdateListener(updateFn func()) { + u.listenerLock.Lock() + defer u.listenerLock.Unlock() + + u.onUpdateListener = updateFn + if u.isUpdateAvailable() { + u.onUpdateListener() + } +} + +func (u *Update) startFetcher() { + changed := u.fetchVersion() + if changed { + u.checkUpdate() + } + + select { + case <-u.fetchDone: + return + case <-u.fetchTicker.C: + changed := u.fetchVersion() + if changed { + u.checkUpdate() + } + } +} + +func (u *Update) fetchVersion() bool { + resp, err := http.Get(versionURL) + if err != nil { + log.Errorf("failed to fetch version info: %s", err) + return false + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Errorf("invalid status code: %d", resp.StatusCode) + return false + } + + if resp.ContentLength > 100 { + log.Errorf("too large response: %d", resp.ContentLength) + return false + } + + content, err := io.ReadAll(resp.Body) + if err != nil { + log.Errorf("failed to read content: %s", err) + return false + } + + latestAvailable, err := goversion.NewVersion(string(content)) + if err != nil { + log.Errorf("failed to parse the version string: %s", err) + return false + } + + u.versionsLock.Lock() + defer u.versionsLock.Unlock() + + if u.latestAvailable.Equal(latestAvailable) { + return false + } + u.latestAvailable = latestAvailable + + return true +} + +func (u *Update) checkUpdate() bool { + if !u.isUpdateAvailable() { + return false + } + + u.listenerLock.Lock() + defer u.listenerLock.Unlock() + if u.onUpdateListener == nil { + return true + } + + go u.onUpdateListener() + return true +} + +func (u *Update) isUpdateAvailable() bool { + u.versionsLock.Lock() + defer u.versionsLock.Unlock() + + if u.latestAvailable.GreaterThan(u.uiVersion) { + return true + } + + if u.daemonVersion == nil { + return false + } + + if u.latestAvailable.GreaterThan(u.daemonVersion) { + return true + } + return false +} diff --git a/version/update_test.go b/version/update_test.go new file mode 100644 index 000000000..4537ce220 --- /dev/null +++ b/version/update_test.go @@ -0,0 +1,101 @@ +package version + +import ( + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +func TestNewUpdate(t *testing.T) { + version = "1.0.0" + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "10.0.0") + })) + defer svr.Close() + versionURL = svr.URL + + wg := &sync.WaitGroup{} + wg.Add(1) + + onUpdate := false + u := NewUpdate() + defer u.StopWatch() + u.SetOnUpdateListener(func() { + onUpdate = true + wg.Done() + }) + + waitTimeout(wg) + if onUpdate != true { + t.Errorf("update not found") + } +} + +func TestDoNotUpdate(t *testing.T) { + version = "11.0.0" + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "10.0.0") + })) + defer svr.Close() + versionURL = svr.URL + + wg := &sync.WaitGroup{} + wg.Add(1) + + onUpdate := false + u := NewUpdate() + defer u.StopWatch() + u.SetOnUpdateListener(func() { + onUpdate = true + wg.Done() + }) + + waitTimeout(wg) + if onUpdate == true { + t.Errorf("invalid update") + } +} + +func TestDaemonUpdate(t *testing.T) { + version = "11.0.0" + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "11.0.0") + })) + defer svr.Close() + versionURL = svr.URL + + wg := &sync.WaitGroup{} + wg.Add(1) + + onUpdate := false + u := NewUpdate() + defer u.StopWatch() + u.SetOnUpdateListener(func() { + onUpdate = true + wg.Done() + }) + + u.SetDaemonVersion("10.0.0") + + waitTimeout(wg) + if onUpdate != true { + t.Errorf("invalid daemon version check") + } +} + +func waitTimeout(wg *sync.WaitGroup) { + c := make(chan struct{}) + go func() { + wg.Wait() + close(c) + }() + select { + case <-c: + return + case <-time.After(time.Second): + return + } +} diff --git a/version/url.go b/version/url.go new file mode 100644 index 000000000..ed43ab042 --- /dev/null +++ b/version/url.go @@ -0,0 +1,5 @@ +package version + +const ( + downloadURL = "https://app.netbird.io/install" +) diff --git a/version/url_darwin.go b/version/url_darwin.go new file mode 100644 index 000000000..cb58612f5 --- /dev/null +++ b/version/url_darwin.go @@ -0,0 +1,33 @@ +package version + +import ( + "os/exec" + "runtime" +) + +const ( + urlMacIntel = "https://pkgs.netbird.io/macos/amd64" + urlMacM1M2 = "https://pkgs.netbird.io/macos/arm64" +) + +// DownloadUrl return with the proper download link +func DownloadUrl() string { + cmd := exec.Command("brew", "list --formula | grep -i netbird") + if err := cmd.Start(); err != nil { + goto PKGINSTALL + } + + if err := cmd.Wait(); err == nil { + return downloadURL + } + +PKGINSTALL: + switch runtime.GOARCH { + case "amd64": + return urlMacIntel + case "arm64": + return urlMacM1M2 + default: + return downloadURL + } +} diff --git a/version/url_linux.go b/version/url_linux.go new file mode 100644 index 000000000..c8193e30c --- /dev/null +++ b/version/url_linux.go @@ -0,0 +1,6 @@ +package version + +// DownloadUrl return with the proper download link +func DownloadUrl() string { + return downloadURL +} diff --git a/version/url_windows.go b/version/url_windows.go new file mode 100644 index 000000000..f2055b109 --- /dev/null +++ b/version/url_windows.go @@ -0,0 +1,19 @@ +package version + +import "golang.org/x/sys/windows/registry" + +const ( + urlWinExe = "https://pkgs.netbird.io/windows/x64" +) + +var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird" + +// DownloadUrl return with the proper download link +func DownloadUrl() string { + _, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE) + if err == nil { + return urlWinExe + } else { + return downloadURL + } +}