Compare commits

...

20 Commits

Author SHA1 Message Date
bcmmbaga
2b86463e96 fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-08-12 13:33:46 +03:00
bcmmbaga
9deff6f06b Merge branch 'main' into handle-existing-domain-user
# Conflicts:
#	management/server/account.go
#	management/server/account_test.go
2025-08-12 13:31:40 +03:00
Pascal Fischer
a942e4add5 [management] use readlock on add peer (#4308) 2025-08-11 15:21:26 +02:00
Viktor Liu
1022a5015c [client] Eliminate upstream server strings in dns code (#4267) 2025-08-11 11:57:21 +02:00
Maycon Santos
375fcf2752 [misc] Post release to forum (#4312) 2025-08-08 21:41:33 +02:00
Maycon Santos
9acf7f9262 [client] Update Windows installer description (#4306)
* [client] Update Windows installer description

* Update netbird.wxs
2025-08-08 21:18:58 +02:00
Viktor Liu
82937ba184 [client] Increase logout timeout (#4311) 2025-08-08 19:16:48 +02:00
Maycon Santos
0f52144894 [misc] Add docs acknowledgement check (#4310)
adds a GitHub Actions workflow to enforce documentation requirements for pull requests, ensuring contributors acknowledge whether their changes need documentation updates or provide a link to a corresponding docs PR.

- Adds a new GitHub Actions workflow that validates documentation acknowledgement in PR descriptions
- Updates the PR template to include mandatory documentation checkboxes and URL field
- Implements validation logic to ensure exactly one documentation option is selected and verifies docs PR URLs when provided
2025-08-08 18:14:26 +02:00
Krzysztof Nazarewski (kdn)
0926400b8a fix: profilemanager panic when reading incomplete config (#4309)
fix: profilemanager panic when reading incomplete config (#4309)
2025-08-08 18:44:25 +03:00
Viktor Liu
bef99d48f8 [client] Rename logout to deregister (#4307) 2025-08-08 15:48:30 +02:00
Pascal Fischer
9e95841252 [management] during JSON migration filter duplicates on conflict (#4303) 2025-08-07 14:12:07 +02:00
hakansa
6da3943559 [client] fix ssh command for non-default profile (#4298)
[client] fix ssh command for non-default profile (#4298)
2025-08-07 13:08:30 +03:00
Pascal Fischer
f5b4659adb [management] Mark SaveAccount deprecated (#4300) 2025-08-07 11:49:37 +02:00
Viktor Liu
3d19468b6c [client] Add windows arm64 build (#4206) 2025-08-07 11:30:19 +02:00
Pascal Fischer
5860e5343f [management] Rework DB locks (#4291) 2025-08-06 18:55:14 +02:00
Misha Bragin
dfd8bbc015 Change Netbird to NetBird in CMD (#4296) 2025-08-06 18:32:35 +02:00
Viktor Liu
abd152ee5a [misc] Separate shared code dependencies (#4288)
* Separate shared code dependencies

* Fix import

* Test respective shared code

* Update openapi ref

* Fix test

* Fix test path
2025-08-05 18:34:41 +02:00
bcmmbaga
1a1e94c805 Check account existence without fully loading it
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-08 18:26:46 +03:00
bcmmbaga
ed939bf7f5 add unit test
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-08 18:09:14 +03:00
bcmmbaga
7caf733217 Skip adding user to domain account if already exists
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-08 15:50:32 +03:00
212 changed files with 1492 additions and 1303 deletions

View File

@@ -12,6 +12,16 @@
- [ ] Is a feature enhancement - [ ] Is a feature enhancement
- [ ] It is a refactor - [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible) - [ ] Created tests that fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md). > By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
## Documentation
Select exactly one:
- [ ] I added/updated documentation for this change
- [ ] Documentation is **not needed** for this change (explain why)
### Docs PR URL (required if "docs added" is checked)
Paste the PR link from https://github.com/netbirdio/docs here:
https://github.com/netbirdio/docs/pull/__

View File

@@ -0,0 +1,41 @@
name: Check License Dependencies
on:
push:
branches: [ main ]
pull_request:
jobs:
check-dependencies:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ ! -z "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
fi
done
if [ $FOUND_ISSUES -eq 1 ]; then
echo ""
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages will change license and should not be imported by client or shared code"
exit 1
else
echo ""
echo "✅ All license dependencies are clean"
fi

94
.github/workflows/docs-ack.yml vendored Normal file
View File

@@ -0,0 +1,94 @@
name: Docs Acknowledgement
on:
pull_request:
types: [opened, edited, synchronize]
permissions:
contents: read
pull-requests: read
jobs:
docs-ack:
name: Require docs PR URL or explicit "not needed"
runs-on: ubuntu-latest
steps:
- name: Read PR body
id: body
run: |
BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH")
echo "body<<EOF" >> $GITHUB_OUTPUT
echo "$BODY" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
- name: Validate checkbox selection
id: validate
run: |
body='${{ steps.body.outputs.body }}'
added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ')
noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ')
if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then
echo "::error::Choose exactly one: either 'docs added' OR 'not needed'."
exit 1
fi
if [ "$added_checked" -eq 0 ] && [ "$noneed_checked" -eq 0 ]; then
echo "::error::You must check exactly one docs option in the PR template."
exit 1
fi
if [ "$added_checked" -eq 1 ]; then
echo "mode=added" >> $GITHUB_OUTPUT
else
echo "mode=noneed" >> $GITHUB_OUTPUT
fi
- name: Extract docs PR URL (when 'docs added')
if: steps.validate.outputs.mode == 'added'
id: extract
run: |
body='${{ steps.body.outputs.body }}'
# Strictly require HTTPS and that it's a PR in netbirdio/docs
# Examples accepted:
# https://github.com/netbirdio/docs/pull/1234
url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)
if [ -z "$url" ]; then
echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)."
exit 1
fi
pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')
echo "url=$url" >> $GITHUB_OUTPUT
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT
- name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added'
uses: actions/github-script@v7
id: verify
with:
pr_number: ${{ steps.extract.outputs.pr_number }}
script: |
const prNumber = parseInt(core.getInput('pr_number'), 10);
const { data } = await github.rest.pulls.get({
owner: 'netbirdio',
repo: 'docs',
pull_number: prNumber
});
// Allow open or merged PRs
const ok = data.state === 'open' || data.merged === true;
core.setOutput('state', data.state);
core.setOutput('merged', String(!!data.merged));
if (!ok) {
core.setFailed(`Docs PR #${prNumber} exists but is neither open nor merged (state=${data.state}, merged=${data.merged}).`);
}
result-encoding: string
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: All good
run: echo "Documentation requirement satisfied ✅"

18
.github/workflows/forum.yml vendored Normal file
View File

@@ -0,0 +1,18 @@
name: Post release topic on Discourse
on:
release:
types: [published]
jobs:
post:
runs-on: ubuntu-latest
steps:
- uses: roots/discourse-topic-github-release-action@main
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -259,7 +259,7 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \ go test ${{ matrix.raceFlag }} \
-exec 'sudo' \ -exec 'sudo' \
-timeout 10m ./relay/... -timeout 10m ./relay/... ./shared/relay/...
test_signal: test_signal:
name: "Signal / Unit" name: "Signal / Unit"
@@ -309,7 +309,7 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \ go test \
-exec 'sudo' \ -exec 'sudo' \
-timeout 10m ./signal/... -timeout 10m ./signal/... ./shared/signal/...
test_management: test_management:
name: "Management / Unit" name: "Management / Unit"
@@ -369,7 +369,7 @@ jobs:
CI=true \ CI=true \
go test -tags=devcert \ go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \ -exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/... -timeout 20m ./management/... ./shared/management/...
benchmark: benchmark:
name: "Management / Benchmark" name: "Management / Benchmark"
@@ -430,7 +430,7 @@ jobs:
CI=true \ CI=true \
go test -tags devcert -run=^$ -bench=. \ go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... -timeout 20m ./management/... ./shared/management/...
api_benchmark: api_benchmark:
name: "Management / Benchmark (API)" name: "Management / Benchmark (API)"
@@ -521,7 +521,7 @@ jobs:
-run=^$ \ -run=^$ \
-bench=. \ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... -timeout 20m ./management/... ./shared/management/...
api_integration_test: api_integration_test:
name: "Management / Integration" name: "Management / Integration"
@@ -571,4 +571,4 @@ jobs:
CI=true \ CI=true \
go test -tags=integration \ go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... -timeout 20m ./management/... ./shared/management/...

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.21" SIGN_PIPE_VER: "v0.0.22"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -79,6 +79,8 @@ jobs:
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
@@ -154,10 +156,20 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install LLVM-MinGW for ARM64 cross-compilation
run: |
cd /tmp
wget -q https://github.com/mstorsjo/llvm-mingw/releases/download/20250709/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
echo "60cafae6474c7411174cff1d4ba21a8e46cadbaeb05a1bace306add301628337 llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz" | sha256sum -c
tar -xf llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
@@ -231,17 +243,3 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -5,7 +5,7 @@ on:
tags: tags:
- 'v*' - 'v*'
paths: paths:
- 'management/server/http/api/openapi.yml' - 'shared/management/http/api/openapi.yml'
jobs: jobs:
trigger_docs_api_update: trigger_docs_api_update:

View File

@@ -16,8 +16,6 @@ builds:
- arm64 - arm64
- 386 - 386
ignore: ignore:
- goos: windows
goarch: arm64
- goos: windows - goos: windows
goarch: arm goarch: arm
- goos: windows - goos: windows

View File

@@ -15,7 +15,7 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows - id: netbird-ui-windows-amd64
dir: client/ui dir: client/ui
binary: netbird-ui binary: netbird-ui
env: env:
@@ -30,6 +30,22 @@ builds:
- -H windowsgui - -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows-arm64
dir: client/ui
binary: netbird-ui
env:
- CGO_ENABLED=1
- CC=aarch64-w64-mingw32-clang
- CXX=aarch64-w64-mingw32-clang++
goos:
- windows
goarch:
- arm64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}"
archives: archives:
- id: linux-arch - id: linux-arch
name_template: "{{ .ProjectName }}-linux_{{ .Version }}_{{ .Os }}_{{ .Arch }}" name_template: "{{ .ProjectName }}-linux_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
@@ -38,7 +54,8 @@ archives:
- id: windows-arch - id: windows-arch
name_template: "{{ .ProjectName }}-windows_{{ .Version }}_{{ .Os }}_{{ .Arch }}" name_template: "{{ .ProjectName }}-windows_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds: builds:
- netbird-ui-windows - netbird-ui-windows-amd64
- netbird-ui-windows-arm64
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>

View File

@@ -4,6 +4,7 @@ package android
import ( import (
"context" "context"
"slices"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -112,7 +113,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -138,7 +139,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -235,7 +236,7 @@ func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
return err return err
} }
dnsServer.OnUpdatedHostDNSServer(list.items) dnsServer.OnUpdatedHostDNSServer(slices.Clone(list.items))
return nil return nil
} }

View File

@@ -1,23 +1,34 @@
package android package android
import "fmt" import (
"fmt"
"net/netip"
// DNSList is a wrapper of []string "github.com/netbirdio/netbird/client/internal/dns"
)
// DNSList is a wrapper of []netip.AddrPort with default DNS port
type DNSList struct { type DNSList struct {
items []string items []netip.AddrPort
} }
// Add new DNS address to the collection // Add new DNS address to the collection, returns error if invalid
func (array *DNSList) Add(s string) { func (array *DNSList) Add(s string) error {
array.items = append(array.items, s) addr, err := netip.ParseAddr(s)
if err != nil {
return fmt.Errorf("invalid DNS address: %s", s)
}
addrPort := netip.AddrPortFrom(addr.Unmap(), dns.DefaultPort)
array.items = append(array.items, addrPort)
return nil
} }
// Get return an element of the collection // Get return an element of the collection as string
func (array *DNSList) Get(i int) (string, error) { func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 { if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range") return "", fmt.Errorf("out of range")
} }
return array.items[i], nil return array.items[i].Addr().String(), nil
} }
// Size return with the size of the collection // Size return with the size of the collection

View File

@@ -3,20 +3,30 @@ package android
import "testing" import "testing"
func TestDNSList_Get(t *testing.T) { func TestDNSList_Get(t *testing.T) {
l := DNSList{ l := DNSList{}
items: make([]string, 1),
// Add a valid DNS address
err := l.Add("8.8.8.8")
if err != nil {
t.Errorf("unexpected error: %s", err)
} }
_, err := l.Get(0) // Test getting valid index
addr, err := l.Get(0)
if err != nil { if err != nil {
t.Errorf("invalid error: %s", err) t.Errorf("invalid error: %s", err)
} }
if addr != "8.8.8.8" {
t.Errorf("expected 8.8.8.8, got %s", addr)
}
// Test negative index
_, err = l.Get(-1) _, err = l.Get(-1)
if err == nil { if err == nil {
t.Errorf("expected error but got nil") t.Errorf("expected error but got nil")
} }
// Test out of bounds index
_, err = l.Get(1) _, err = l.Get(1)
if err == nil { if err == nil {
t.Errorf("expected error but got nil") t.Errorf("expected error but got nil")

View File

@@ -33,7 +33,7 @@ var (
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
Use: "debug", Use: "debug",
Short: "Debugging commands", Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the Netbird daemon.", Long: "Provides commands for debugging and logging control within the NetBird daemon.",
} }
var debugBundleCmd = &cobra.Command{ var debugBundleCmd = &cobra.Command{
@@ -46,8 +46,8 @@ var debugBundleCmd = &cobra.Command{
var logCmd = &cobra.Command{ var logCmd = &cobra.Command{
Use: "log", Use: "log",
Short: "Manage logging for the Netbird daemon", Short: "Manage logging for the NetBird daemon",
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`, Long: `Commands to manage logging settings for the NetBird daemon, including ICE, gRPC, and general log levels.`,
} }
var logLevelCmd = &cobra.Command{ var logLevelCmd = &cobra.Command{
@@ -184,7 +184,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird up") cmd.Println("netbird up")
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
} }
@@ -202,7 +202,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird down") cmd.Println("netbird down")
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@@ -216,11 +216,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird up") cmd.Println("netbird up")
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag)) statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
@@ -230,7 +230,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...") cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
@@ -250,7 +250,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird down") cmd.Println("netbird down")
} }
if !initialLevelTrace { if !initialLevelTrace {

View File

@@ -31,7 +31,7 @@ func init() {
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
Use: "login", Use: "login",
Short: "login to the Netbird Management Service (first run)", Short: "login to the NetBird Management Service (first run)",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setEnvAndFlags(cmd); err != nil { if err := setEnvAndFlags(cmd); err != nil {
return fmt.Errorf("set env and flags: %v", err) return fmt.Errorf("set env and flags: %v", err)

View File

@@ -12,14 +12,15 @@ import (
) )
var logoutCmd = &cobra.Command{ var logoutCmd = &cobra.Command{
Use: "logout", Use: "deregister",
Short: "logout from the Netbird Management Service and delete peer", Aliases: []string{"logout"},
Short: "deregister from the NetBird Management Service and delete peer",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) ctx, cancel := context.WithTimeout(cmd.Context(), time.Second*15)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
@@ -44,10 +45,10 @@ var logoutCmd = &cobra.Command{
} }
if _, err := daemonClient.Logout(ctx, req); err != nil { if _, err := daemonClient.Logout(ctx, req); err != nil {
return fmt.Errorf("logout: %v", err) return fmt.Errorf("deregister: %v", err)
} }
cmd.Println("Logged out successfully") cmd.Println("Deregistered successfully")
return nil return nil
}, },
} }

View File

@@ -16,14 +16,14 @@ import (
var profileCmd = &cobra.Command{ var profileCmd = &cobra.Command{
Use: "profile", Use: "profile",
Short: "manage Netbird profiles", Short: "manage NetBird profiles",
Long: `Manage Netbird profiles, allowing you to list, switch, and remove profiles.`, Long: `Manage NetBird profiles, allowing you to list, switch, and remove profiles.`,
} }
var profileListCmd = &cobra.Command{ var profileListCmd = &cobra.Command{
Use: "list", Use: "list",
Short: "list all profiles", Short: "list all profiles",
Long: `List all available profiles in the Netbird client.`, Long: `List all available profiles in the NetBird client.`,
Aliases: []string{"ls"}, Aliases: []string{"ls"},
RunE: listProfilesFunc, RunE: listProfilesFunc,
} }
@@ -31,7 +31,7 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{ var profileAddCmd = &cobra.Command{
Use: "add <profile_name>", Use: "add <profile_name>",
Short: "add a new profile", Short: "add a new profile",
Long: `Add a new profile to the Netbird client. The profile name must be unique.`, Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: addProfileFunc, RunE: addProfileFunc,
} }
@@ -39,7 +39,7 @@ var profileAddCmd = &cobra.Command{
var profileRemoveCmd = &cobra.Command{ var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>", Use: "remove <profile_name>",
Short: "remove a profile", Short: "remove a profile",
Long: `Remove a profile from the Netbird client. The profile must not be active.`, Long: `Remove a profile from the NetBird client. The profile must not be active.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: removeProfileFunc, RunE: removeProfileFunc,
} }
@@ -47,7 +47,7 @@ var profileRemoveCmd = &cobra.Command{
var profileSelectCmd = &cobra.Command{ var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>", Use: "select <profile_name>",
Short: "select a profile", Short: "select a profile",
Long: `Select a profile to be the active profile in the Netbird client. The profile must exist.`, Long: `Select a profile to be the active profile in the NetBird client. The profile must exist.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: selectProfileFunc, RunE: selectProfileFunc,
} }

View File

@@ -119,12 +119,12 @@ func init() {
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL)) rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL))
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL)) rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets NetBird log level")
rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets Netbird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character") rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets NetBird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.") rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file") rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")

View File

@@ -50,10 +50,10 @@ func TestSetFlagsFromEnvVars(t *testing.T) {
} }
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`comma separated list of external IPs to map to the Wireguard interface`) `comma separated list of external IPs to map to the WireGuard interface`)
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.") cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec") t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name") t.Setenv("NB_INTERFACE_NAME", "test-name")

View File

@@ -19,7 +19,7 @@ import (
var serviceCmd = &cobra.Command{ var serviceCmd = &cobra.Command{
Use: "service", Use: "service",
Short: "manages Netbird service", Short: "manages NetBird service",
} }
var ( var (
@@ -47,7 +47,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` + serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` + `You can specify a comma-separated list of KEY=VALUE pairs. ` +
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value` `E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
@@ -64,7 +64,7 @@ func newSVCConfig() (*service.Config, error) {
config := &service.Config{ config := &service.Config{
Name: serviceName, Name: serviceName,
DisplayName: "Netbird", DisplayName: "Netbird",
Description: "Netbird mesh network client", Description: "NetBird mesh network client",
Option: make(service.KeyValue), Option: make(service.KeyValue),
EnvVars: make(map[string]string), EnvVars: make(map[string]string),
} }

View File

@@ -24,7 +24,7 @@ import (
func (p *program) Start(svc service.Service) error { func (p *program) Start(svc service.Service) error {
// Start should not block. Do the actual work async. // Start should not block. Do the actual work async.
log.Info("starting Netbird service") //nolint log.Info("starting NetBird service") //nolint
// Collect static system and platform information // Collect static system and platform information
system.UpdateStaticInfo() system.UpdateStaticInfo()
@@ -97,7 +97,7 @@ func (p *program) Stop(srv service.Service) error {
} }
time.Sleep(time.Second * 2) time.Sleep(time.Second * 2)
log.Info("stopped Netbird service") //nolint log.Info("stopped NetBird service") //nolint
return nil return nil
} }
@@ -131,7 +131,7 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel
var runCmd = &cobra.Command{ var runCmd = &cobra.Command{
Use: "run", Use: "run",
Short: "runs Netbird as service", Short: "runs NetBird as service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
@@ -149,7 +149,7 @@ var runCmd = &cobra.Command{
var startCmd = &cobra.Command{ var startCmd = &cobra.Command{
Use: "start", Use: "start",
Short: "starts Netbird service", Short: "starts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel) s, err := setupServiceControlCommand(cmd, ctx, cancel)
@@ -160,14 +160,14 @@ var startCmd = &cobra.Command{
if err := s.Start(); err != nil { if err := s.Start(); err != nil {
return fmt.Errorf("start service: %w", err) return fmt.Errorf("start service: %w", err)
} }
cmd.Println("Netbird service has been started") cmd.Println("NetBird service has been started")
return nil return nil
}, },
} }
var stopCmd = &cobra.Command{ var stopCmd = &cobra.Command{
Use: "stop", Use: "stop",
Short: "stops Netbird service", Short: "stops NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel) s, err := setupServiceControlCommand(cmd, ctx, cancel)
@@ -178,14 +178,14 @@ var stopCmd = &cobra.Command{
if err := s.Stop(); err != nil { if err := s.Stop(); err != nil {
return fmt.Errorf("stop service: %w", err) return fmt.Errorf("stop service: %w", err)
} }
cmd.Println("Netbird service has been stopped") cmd.Println("NetBird service has been stopped")
return nil return nil
}, },
} }
var restartCmd = &cobra.Command{ var restartCmd = &cobra.Command{
Use: "restart", Use: "restart",
Short: "restarts Netbird service", Short: "restarts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel) s, err := setupServiceControlCommand(cmd, ctx, cancel)
@@ -196,14 +196,14 @@ var restartCmd = &cobra.Command{
if err := s.Restart(); err != nil { if err := s.Restart(); err != nil {
return fmt.Errorf("restart service: %w", err) return fmt.Errorf("restart service: %w", err)
} }
cmd.Println("Netbird service has been restarted") cmd.Println("NetBird service has been restarted")
return nil return nil
}, },
} }
var svcStatusCmd = &cobra.Command{ var svcStatusCmd = &cobra.Command{
Use: "status", Use: "status",
Short: "shows Netbird service status", Short: "shows NetBird service status",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel) s, err := setupServiceControlCommand(cmd, ctx, cancel)
@@ -228,7 +228,7 @@ var svcStatusCmd = &cobra.Command{
statusText = fmt.Sprintf("Unknown (%d)", status) statusText = fmt.Sprintf("Unknown (%d)", status)
} }
cmd.Printf("Netbird service status: %s\n", statusText) cmd.Printf("NetBird service status: %s\n", statusText)
return nil return nil
}, },
} }

View File

@@ -99,7 +99,7 @@ func createServiceConfigForInstall() (*service.Config, error) {
var installCmd = &cobra.Command{ var installCmd = &cobra.Command{
Use: "install", Use: "install",
Short: "installs Netbird service", Short: "installs NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil { if err := setupServiceCommand(cmd); err != nil {
return err return err
@@ -122,14 +122,14 @@ var installCmd = &cobra.Command{
return fmt.Errorf("install service: %w", err) return fmt.Errorf("install service: %w", err)
} }
cmd.Println("Netbird service has been installed") cmd.Println("NetBird service has been installed")
return nil return nil
}, },
} }
var uninstallCmd = &cobra.Command{ var uninstallCmd = &cobra.Command{
Use: "uninstall", Use: "uninstall",
Short: "uninstalls Netbird service from system", Short: "uninstalls NetBird service from system",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil { if err := setupServiceCommand(cmd); err != nil {
return err return err
@@ -152,15 +152,15 @@ var uninstallCmd = &cobra.Command{
return fmt.Errorf("uninstall service: %w", err) return fmt.Errorf("uninstall service: %w", err)
} }
cmd.Println("Netbird service has been uninstalled") cmd.Println("NetBird service has been uninstalled")
return nil return nil
}, },
} }
var reconfigureCmd = &cobra.Command{ var reconfigureCmd = &cobra.Command{
Use: "reconfigure", Use: "reconfigure",
Short: "reconfigures Netbird service with new settings", Short: "reconfigures NetBird service with new settings",
Long: `Reconfigures the Netbird service with new settings without manual uninstall/install. Long: `Reconfigures the NetBird service with new settings without manual uninstall/install.
This command will temporarily stop the service, update its configuration, and restart it if it was running.`, This command will temporarily stop the service, update its configuration, and restart it if it was running.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil { if err := setupServiceCommand(cmd); err != nil {
@@ -186,7 +186,7 @@ This command will temporarily stop the service, update its configuration, and re
} }
if wasRunning { if wasRunning {
cmd.Println("Stopping Netbird service...") cmd.Println("Stopping NetBird service...")
if err := s.Stop(); err != nil { if err := s.Stop(); err != nil {
cmd.Printf("Warning: failed to stop service: %v\n", err) cmd.Printf("Warning: failed to stop service: %v\n", err)
} }
@@ -203,13 +203,13 @@ This command will temporarily stop the service, update its configuration, and re
} }
if wasRunning { if wasRunning {
cmd.Println("Starting Netbird service...") cmd.Println("Starting NetBird service...")
if err := s.Start(); err != nil { if err := s.Start(); err != nil {
return fmt.Errorf("start service after reconfigure: %w", err) return fmt.Errorf("start service after reconfigure: %w", err)
} }
cmd.Println("Netbird service has been reconfigured and started") cmd.Println("NetBird service has been reconfigured and started")
} else { } else {
cmd.Println("Netbird service has been reconfigured") cmd.Println("NetBird service has been reconfigured")
} }
return nil return nil

View File

@@ -59,8 +59,8 @@ var sshCmd = &cobra.Command{
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
pm := profilemanager.NewProfileManager() sm := profilemanager.NewServiceManager(configPath)
activeProf, err := pm.GetActiveProfile() activeProf, err := sm.GetActiveProfileState()
if err != nil { if err != nil {
return fmt.Errorf("get active profile: %v", err) return fmt.Errorf("get active profile: %v", err)
} }

View File

@@ -17,7 +17,7 @@ var (
var stateCmd = &cobra.Command{ var stateCmd = &cobra.Command{
Use: "state", Use: "state",
Short: "Manage daemon state", Short: "Manage daemon state",
Long: "Provides commands for managing and inspecting the Netbird daemon state.", Long: "Provides commands for managing and inspecting the NetBird daemon state.",
} }
var stateListCmd = &cobra.Command{ var stateListCmd = &cobra.Command{

View File

@@ -53,15 +53,15 @@ var (
upCmd = &cobra.Command{ upCmd = &cobra.Command{
Use: "up", Use: "up",
Short: "install, login and start Netbird client", Short: "install, login and start NetBird client",
RunE: upFunc, RunE: upFunc,
} }
) )
func init() { func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+ `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
@@ -79,7 +79,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location. ") upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
} }

View File

@@ -9,7 +9,7 @@ import (
var ( var (
versionCmd = &cobra.Command{ versionCmd = &cobra.Command{
Use: "version", Use: "version",
Short: "prints Netbird version", Short: "prints NetBird version",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
cmd.Println(version.NetbirdVersion()) cmd.Println(version.NetbirdVersion())

View File

@@ -3,7 +3,7 @@
!define WEB_SITE "Netbird.io" !define WEB_SITE "Netbird.io"
!define VERSION $%APPVER% !define VERSION $%APPVER%
!define COPYRIGHT "Netbird Authors, 2022" !define COPYRIGHT "Netbird Authors, 2022"
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network" !define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
!define INSTALLER_NAME "netbird-installer.exe" !define INSTALLER_NAME "netbird-installer.exe"
!define MAIN_APP_EXE "Netbird" !define MAIN_APP_EXE "Netbird"
!define ICON "ui\\assets\\netbird.ico" !define ICON "ui\\assets\\netbird.ico"
@@ -59,9 +59,15 @@ ShowInstDetails Show
!define MUI_UNICON "${ICON}" !define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
!define MUI_UNWELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_UNWELCOMEFINISHPAGE_BITMAP "${BANNER}"
!ifndef ARCH
!define ARCH "amd64"
!endif
!if ${ARCH} == "amd64"
!define MUI_FINISHPAGE_RUN !define MUI_FINISHPAGE_RUN
!define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}" !define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}"
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink" !define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
!endif
###################################################################### ######################################################################
!define MUI_ABORTWARNING !define MUI_ABORTWARNING
@@ -213,7 +219,15 @@ Section -MainProgram
${INSTALL_TYPE} ${INSTALL_TYPE}
# SetOverwrite ifnewer # SetOverwrite ifnewer
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
!ifndef ARCH
!define ARCH "amd64"
!endif
!if ${ARCH} == "arm64"
File /r "..\\dist\\netbird_windows_arm64\\"
!else
File /r "..\\dist\\netbird_windows_amd64\\" File /r "..\\dist\\netbird_windows_amd64\\"
!endif
SectionEnd SectionEnd
###################################################################### ######################################################################
@@ -292,7 +306,9 @@ DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
!if ${ARCH} == "amd64"
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
!endif
DetailPrint "Removing application directory..." DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
@@ -314,8 +330,10 @@ DetailPrint "Uninstallation finished."
SectionEnd SectionEnd
!if ${ARCH} == "amd64"
Function LaunchLink Function LaunchLink
SetShellVarContext all SetShellVarContext all
SetOutPath $INSTDIR SetOutPath $INSTDIR
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk" ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
FunctionEnd FunctionEnd
!endif

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
@@ -70,7 +71,7 @@ func (c *ConnectClient) RunOnAndroid(
tunAdapter device.TunAdapter, tunAdapter device.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover, iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsAddresses []string, dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener, dnsReadyListener dns.ReadyListener,
) error { ) error {
// in case of non Android os these variables will be nil // in case of non Android os these variables will be nil

View File

@@ -16,7 +16,7 @@ const (
) )
type resolvConf struct { type resolvConf struct {
nameServers []string nameServers []netip.Addr
searchDomains []string searchDomains []string
others []string others []string
} }
@@ -36,7 +36,7 @@ func parseBackupResolvConf() (*resolvConf, error) {
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
rconf := &resolvConf{ rconf := &resolvConf{
searchDomains: make([]string, 0), searchDomains: make([]string, 0),
nameServers: make([]string, 0), nameServers: make([]netip.Addr, 0),
others: make([]string, 0), others: make([]string, 0),
} }
@@ -94,7 +94,11 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
if len(splitLines) != 2 { if len(splitLines) != 2 {
continue continue
} }
rconf.nameServers = append(rconf.nameServers, splitLines[1]) if addr, err := netip.ParseAddr(splitLines[1]); err == nil {
rconf.nameServers = append(rconf.nameServers, addr.Unmap())
} else {
log.Warnf("invalid nameserver address in resolv.conf: %s, skipping", splitLines[1])
}
continue continue
} }
@@ -104,31 +108,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
} }
return rconf, nil return rconf, nil
} }
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location
func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
resolvConf, err := parseResolvConfFile(filename)
if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err)
}
content, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("read %s: %w", filename, err)
}
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
stat, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("stat %s: %w", filename, err)
}
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
return fmt.Errorf("write %s: %w", filename, err)
}
}
return nil
}

View File

@@ -3,13 +3,9 @@
package dns package dns
import ( import (
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_parseResolvConf(t *testing.T) { func Test_parseResolvConf(t *testing.T) {
@@ -99,9 +95,13 @@ options debug
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains) t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
} }
ok = compareLists(cfg.nameServers, testCase.expectedNS) nsStrings := make([]string, len(cfg.nameServers))
for i, ns := range cfg.nameServers {
nsStrings[i] = ns.String()
}
ok = compareLists(nsStrings, testCase.expectedNS)
if !ok { if !ok {
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers) t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, nsStrings)
} }
ok = compareLists(cfg.others, testCase.expectedOther) ok = compareLists(cfg.others, testCase.expectedOther)
@@ -177,86 +177,3 @@ nameserver 192.168.0.1
} }
} }
func TestRemoveFirstNbNameserver(t *testing.T) {
testCases := []struct {
name string
content string
ipToRemove string
expected string
}{
{
name: "Unrelated nameservers with comments and options",
content: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
ipToRemove: "9.9.9.9",
expected: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
},
{
name: "First nameserver matches",
content: `search example.com
nameserver 9.9.9.9
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
ipToRemove: "9.9.9.9",
expected: `search example.com
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
},
{
name: "Target IP not the first nameserver",
// nolint:dupword
content: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
ipToRemove: "9.9.9.9",
// nolint:dupword
expected: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
},
{
name: "Only nameserver matches",
content: `options debug
nameserver 9.9.9.9
search localdomain`,
ipToRemove: "9.9.9.9",
expected: `options debug
nameserver 9.9.9.9
search localdomain`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "resolv.conf")
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
assert.NoError(t, err)
ip, err := netip.ParseAddr(tc.ipToRemove)
require.NoError(t, err, "Failed to parse IP address")
err = removeFirstNbNameserver(tempFile, ip)
assert.NoError(t, err)
content, err := os.ReadFile(tempFile)
assert.NoError(t, err)
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
})
}
}

View File

@@ -146,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rCon
return true return true
} }
if rConf.nameServers[0] != nbNameserverIP.String() { if rConf.nameServers[0] != nbNameserverIP {
return true return true
} }

View File

@@ -29,7 +29,7 @@ type fileConfigurator struct {
repair *repair repair *repair
originalPerms os.FileMode originalPerms os.FileMode
nbNameserverIP netip.Addr nbNameserverIP netip.Addr
originalNameservers []string originalNameservers []netip.Addr
} }
func newFileConfigurator() (*fileConfigurator, error) { func newFileConfigurator() (*fileConfigurator, error) {
@@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
} }
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf // getOriginalNameservers returns the nameservers that were found in the original resolv.conf
func (f *fileConfigurator) getOriginalNameservers() []string { func (f *fileConfigurator) getOriginalNameservers() []netip.Addr {
return f.originalNameservers return f.originalNameservers
} }
@@ -128,20 +128,14 @@ func (f *fileConfigurator) backup() error {
} }
func (f *fileConfigurator) restore() error { func (f *fileConfigurator) restore() error {
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP) if err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); err != nil {
if err != nil {
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
}
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
} }
return os.RemoveAll(fileDefaultResolvConfBackupLocation) return os.RemoveAll(fileDefaultResolvConfBackupLocation)
} }
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error {
resolvConf, err := parseDefaultResolvConf() resolvConf, err := parseDefaultResolvConf()
if err != nil { if err != nil {
return fmt.Errorf("parse current resolv.conf: %w", err) return fmt.Errorf("parse current resolv.conf: %w", err)
@@ -152,16 +146,9 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile() return restoreResolvConfFile()
} }
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
// not a valid first nameserver -> restore
if err != nil {
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
return restoreResolvConfFile()
}
// current address is still netbird's non-available dns address -> restore // current address is still netbird's non-available dns address -> restore
// comparing parsed addresses only, to remove ambiguity currentDNSAddress := resolvConf.nameServers[0]
if currentDNSAddress.String() == storedDNSAddress.String() { if currentDNSAddress == storedDNSAddress {
return restoreResolvConfFile() return restoreResolvConfFile()
} }

View File

@@ -239,7 +239,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} else if inServerAddressesArray { } else if inServerAddressesArray {
address := strings.Split(line, " : ")[1] address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address inServerAddressesArray = false // Stop reading after finding the first IPv4 address
} }
} }
@@ -250,7 +250,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} }
// default to 53 port // default to 53 port
dnsSettings.ServerPort = defaultPort dnsSettings.ServerPort = DefaultPort
return dnsSettings, nil return dnsSettings, nil
} }

View File

@@ -42,7 +42,7 @@ func (t osManagerType) String() string {
type restoreHostManager interface { type restoreHostManager interface {
hostManager hostManager
restoreUncleanShutdownDNS(*netip.Addr) error restoreUncleanShutdownDNS(netip.Addr) error
} }
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
@@ -130,8 +130,9 @@ func checkStub() bool {
return true return true
} }
systemdResolvedAddr := netip.AddrFrom4([4]byte{127, 0, 0, 53}) // 127.0.0.53
for _, ns := range rConf.nameServers { for _, ns := range rConf.nameServers {
if ns == "127.0.0.53" { if ns == systemdResolvedAddr {
return true return true
} }
} }

View File

@@ -216,7 +216,7 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
return fmt.Errorf("adding dns setup for all failed: %w", err) return fmt.Errorf("adding dns setup for all failed: %w", err)
} }
r.routingAll = true r.routingAll = true
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip) log.Infof("configured %s:%d as main DNS forwarder for this peer", ip, DefaultPort)
return nil return nil
} }

View File

@@ -1,38 +1,31 @@
package dns package dns
import ( import (
"fmt"
"net/netip" "net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus"
) )
type hostsDNSHolder struct { type hostsDNSHolder struct {
unprotectedDNSList map[string]struct{} unprotectedDNSList map[netip.AddrPort]struct{}
mutex sync.RWMutex mutex sync.RWMutex
} }
func newHostsDNSHolder() *hostsDNSHolder { func newHostsDNSHolder() *hostsDNSHolder {
return &hostsDNSHolder{ return &hostsDNSHolder{
unprotectedDNSList: make(map[string]struct{}), unprotectedDNSList: make(map[netip.AddrPort]struct{}),
} }
} }
func (h *hostsDNSHolder) set(list []string) { func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Lock() h.mutex.Lock()
h.unprotectedDNSList = make(map[string]struct{}) h.unprotectedDNSList = make(map[netip.AddrPort]struct{})
for _, dns := range list { for _, addrPort := range list {
dnsAddr, err := h.normalizeAddress(dns) h.unprotectedDNSList[addrPort] = struct{}{}
if err != nil {
continue
}
h.unprotectedDNSList[dnsAddr] = struct{}{}
} }
h.mutex.Unlock() h.mutex.Unlock()
} }
func (h *hostsDNSHolder) get() map[string]struct{} { func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock() h.mutex.RLock()
l := h.unprotectedDNSList l := h.unprotectedDNSList
h.mutex.RUnlock() h.mutex.RUnlock()
@@ -40,24 +33,10 @@ func (h *hostsDNSHolder) get() map[string]struct{} {
} }
//nolint:unused //nolint:unused
func (h *hostsDNSHolder) isContain(upstream string) bool { func (h *hostsDNSHolder) contains(upstream netip.AddrPort) bool {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
_, ok := h.unprotectedDNSList[upstream] _, ok := h.unprotectedDNSList[upstream]
return ok return ok
} }
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
a, err := netip.ParseAddr(addr)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
return "", err
}
if a.Is4() {
return fmt.Sprintf("%s:53", addr), nil
} else {
return fmt.Sprintf("[%s]:53", addr), nil
}
}

View File

@@ -50,7 +50,7 @@ func (m *MockServer) DnsIP() netip.Addr {
return netip.MustParseAddr("100.10.254.255") return netip.MustParseAddr("100.10.254.255")
} }
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) {
// TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }

View File

@@ -245,7 +245,7 @@ func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
return nil return nil
} }
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := n.restoreHostDNS(); err != nil { if err := n.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via network-manager: %w", err) return fmt.Errorf("restoring dns via network-manager: %w", err)
} }

View File

@@ -40,7 +40,7 @@ type resolvconf struct {
implType resolvconfType implType resolvconfType
originalSearchDomains []string originalSearchDomains []string
originalNameServers []string originalNameServers []netip.Addr
othersConfigs []string othersConfigs []string
} }
@@ -110,7 +110,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
return nil return nil
} }
func (r *resolvconf) getOriginalNameservers() []string { func (r *resolvconf) getOriginalNameservers() []netip.Addr {
return r.originalNameServers return r.originalNameServers
} }
@@ -158,7 +158,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
return nil return nil
} }
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error { func (r *resolvconf) restoreUncleanShutdownDNS(netip.Addr) error {
if err := r.restoreHostDNS(); err != nil { if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err) return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
} }

View File

@@ -42,7 +42,7 @@ type Server interface {
Stop() Stop()
DnsIP() netip.Addr DnsIP() netip.Addr
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string) OnUpdatedHostDNSServer(addrs []netip.AddrPort)
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
} }
@@ -55,7 +55,7 @@ type nsGroupsByDomain struct {
// hostManagerWithOriginalNS extends the basic hostManager interface // hostManagerWithOriginalNS extends the basic hostManager interface
type hostManagerWithOriginalNS interface { type hostManagerWithOriginalNS interface {
hostManager hostManager
getOriginalNameservers() []string getOriginalNameservers() []netip.Addr
} }
// DefaultServer dns server object // DefaultServer dns server object
@@ -136,7 +136,7 @@ func NewDefaultServer(
func NewDefaultServerPermanentUpstream( func NewDefaultServerPermanentUpstream(
ctx context.Context, ctx context.Context,
wgInterface WGIface, wgInterface WGIface,
hostsDnsList []string, hostsDnsList []netip.AddrPort,
config nbdns.Config, config nbdns.Config,
listener listener.NetworkChangeListener, listener listener.NetworkChangeListener,
statusRecorder *peer.Status, statusRecorder *peer.Status,
@@ -144,6 +144,7 @@ func NewDefaultServerPermanentUpstream(
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.hostsDNSHolder.set(hostsDnsList) ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.addHostRootZone() ds.addHostRootZone()
@@ -340,7 +341,7 @@ func (s *DefaultServer) disableDNS() error {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) {
s.hostsDNSHolder.set(hostsDnsList) s.hostsDNSHolder.set(hostsDnsList)
// Check if there's any root handler // Check if there's any root handler
@@ -461,7 +462,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != DefaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
@@ -581,14 +582,13 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
} }
for _, ns := range originalNameservers { for _, ns := range originalNameservers {
if ns == config.ServerIP.String() { if ns == config.ServerIP {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
continue continue
} }
ns = formatAddr(ns, defaultPort) addrPort := netip.AddrPortFrom(ns, DefaultPort)
handler.upstreamServers = append(handler.upstreamServers, addrPort)
handler.upstreamServers = append(handler.upstreamServers, ns)
} }
handler.deactivate = func(error) { /* always active */ } handler.deactivate = func(error) { /* always active */ }
handler.reactivate = func() { /* always active */ } handler.reactivate = func() { /* always active */ }
@@ -695,7 +695,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue continue
} }
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns)) handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
} }
if len(handler.upstreamServers) == 0 { if len(handler.upstreamServers) == 0 {
@@ -770,18 +770,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
s.dnsMuxMap = muxUpdateMap s.dnsMuxMap = muxUpdateMap
} }
func getNSHostPort(ns nbdns.NameServer) string {
return formatAddr(ns.IP.String(), ns.Port)
}
// formatAddr formats a nameserver address with port, handling IPv6 addresses properly
func formatAddr(address string, port int) string {
if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() {
return fmt.Sprintf("[%s]:%d", address, port)
}
return fmt.Sprintf("%s:%d", address, port)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate // upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to // the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate. // reactivate it. Not allowed to call reactivate before deactivate.
@@ -879,10 +867,7 @@ func (s *DefaultServer) addHostRootZone() {
return return
} }
handler.upstreamServers = make([]string, 0) handler.upstreamServers = maps.Keys(hostDNSServers)
for k := range hostDNSServers {
handler.upstreamServers = append(handler.upstreamServers, k)
}
handler.deactivate = func(error) {} handler.deactivate = func(error) {}
handler.reactivate = func() {} handler.reactivate = func() {}
@@ -893,9 +878,9 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
var states []peer.NSGroupState var states []peer.NSGroupState
for _, group := range groups { for _, group := range groups {
var servers []string var servers []netip.AddrPort
for _, ns := range group.NameServers { for _, ns := range group.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, ns.AddrPort())
} }
state := peer.NSGroupState{ state := peer.NSGroupState{
@@ -927,7 +912,7 @@ func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error,
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
var servers []string var servers []string
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, ns.AddrPort().String())
} }
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
} }

View File

@@ -97,9 +97,9 @@ func init() {
} }
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string var srvs []netip.AddrPort
for _, srv := range servers { for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv)) srvs = append(srvs, srv.AddrPort())
} }
return &upstreamResolverBase{ return &upstreamResolverBase{
domain: domain, domain: domain,
@@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
var dnsList []string var dnsList []netip.AddrPort
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false) dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
@@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
} }
defer dnsServer.Stop() defer dnsServer.Stop()
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"}) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort})
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort()) resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io") _, err = resolver.LookupHost(context.Background(), "netbird.io")
@@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
if err != nil { if err != nil {
t.Errorf("failed to initialize DNS server: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
@@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
if err != nil { if err != nil {
t.Errorf("failed to initialize DNS server: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
@@ -2053,56 +2056,3 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
} }
func TestFormatAddr(t *testing.T) {
tests := []struct {
name string
address string
port int
expected string
}{
{
name: "IPv4 address",
address: "8.8.8.8",
port: 53,
expected: "8.8.8.8:53",
},
{
name: "IPv4 address with custom port",
address: "1.1.1.1",
port: 5353,
expected: "1.1.1.1:5353",
},
{
name: "IPv6 address",
address: "fd78:94bf:7df8::1",
port: 53,
expected: "[fd78:94bf:7df8::1]:53",
},
{
name: "IPv6 address with custom port",
address: "2001:db8::1",
port: 5353,
expected: "[2001:db8::1]:5353",
},
{
name: "IPv6 localhost",
address: "::1",
port: 53,
expected: "[::1]:53",
},
{
name: "Invalid address treated as hostname",
address: "dns.example.com",
port: 53,
expected: "dns.example.com:53",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatAddr(tt.address, tt.port)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -7,7 +7,7 @@ import (
) )
const ( const (
defaultPort = 53 DefaultPort = 53
) )
type service interface { type service interface {

View File

@@ -122,7 +122,7 @@ func (s *serviceViaListener) RuntimePort() int {
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil { if s.ebpfService != nil {
return defaultPort return DefaultPort
} else { } else {
return int(s.listenPort) return int(s.listenPort)
} }
@@ -148,9 +148,9 @@ func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
return s.customAddr.Addr(), s.customAddr.Port(), nil return s.customAddr.Addr(), s.customAddr.Port(), nil
} }
ip, ok := s.testFreePort(defaultPort) ip, ok := s.testFreePort(DefaultPort)
if ok { if ok {
return ip, defaultPort, nil return ip, DefaultPort, nil
} }
ebpfSrv, port, ok := s.tryToUseeBPF() ebpfSrv, port, ok := s.tryToUseeBPF()

View File

@@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: lastIP, runtimeIP: lastIP,
runtimePort: defaultPort, runtimePort: DefaultPort,
} }
return s return s
} }

View File

@@ -235,7 +235,7 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
return nil return nil
} }
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via systemd: %w", err) return fmt.Errorf("restoring dns via systemd: %w", err)
} }

View File

@@ -27,7 +27,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create previous host manager: %w", err) return fmt.Errorf("create previous host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil { if err := manager.restoreUncleanShutdownDNS(s.DNSAddress); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }

View File

@@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
upstreamClient upstreamClient upstreamClient upstreamClient
upstreamServers []string upstreamServers []netip.AddrPort
domain string domain string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
@@ -79,17 +80,20 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
// String returns a string representation of the upstream resolver // String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string { func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %v", u.upstreamServers) return fmt.Sprintf("upstream %s", u.upstreamServers)
} }
// ID returns the unique handler ID // ID returns the unique handler ID
func (u *upstreamResolverBase) ID() types.HandlerID { func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers) servers := slices.Clone(u.upstreamServers)
slices.Sort(servers) slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New() hash := sha256.New()
hash.Write([]byte(u.domain + ":")) hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ","))) for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
} }
@@ -130,7 +134,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func() { func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel() defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r) rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}() }()
if err != nil { if err != nil {
@@ -197,7 +201,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
proto.SystemEvent_DNS, proto.SystemEvent_DNS,
"All upstream servers failed (fail count exceeded)", "All upstream servers failed (fail count exceeded)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, map[string]string{"upstreams": u.upstreamServersString()},
// TODO add domain meta // TODO add domain meta
) )
} }
@@ -258,7 +262,7 @@ func (u *upstreamResolverBase) ProbeAvailability() {
proto.SystemEvent_DNS, proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)", "All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, map[string]string{"upstreams": u.upstreamServersString()},
) )
} }
} }
@@ -278,7 +282,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
operation := func() error { operation := func() error {
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers)) return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default: default:
} }
@@ -291,7 +295,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
} }
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff()) log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error") return fmt.Errorf("upstream check call error")
} }
@@ -301,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return return
} }
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.failsCount.Store(0) u.failsCount.Store(0)
u.successCount.Add(1) u.successCount.Add(1)
u.reactivate() u.reactivate()
@@ -331,13 +335,21 @@ func (u *upstreamResolverBase) disable(err error) {
go u.waitUntilResponse() go u.waitUntilResponse()
} }
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error { func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout) ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel() defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(ctx, server, r) _, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
return err return err
} }

View File

@@ -79,8 +79,8 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
} }
func (u *upstreamResolver) isLocalResolver(upstream string) bool { func (u *upstreamResolver) isLocalResolver(upstream string) bool {
if u.hostsDNSHolder.isContain(upstream) { if addrPort, err := netip.ParseAddrPort(upstream); err == nil {
return true return u.hostsDNSHolder.contains(addrPort)
} }
return false return false
} }

View File

@@ -62,6 +62,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP, err := netip.ParseAddr(upstreamHost) upstreamIP, err := netip.ParseAddr(upstreamHost)
if err != nil { if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err) log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
} else {
upstreamIP = upstreamIP.Unmap()
} }
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() { if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)

View File

@@ -59,7 +59,14 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers // Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
if addrPort, err := netip.ParseAddrPort(server); err == nil {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {
cancel() cancel()
@@ -128,7 +135,8 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
} }
resolver.upstreamServers = []string{"0.0.0.0:-1"} addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
resolver.failsTillDeact = 0 resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100 resolver.reactivatePeriod = time.Microsecond * 100

View File

@@ -1,6 +1,8 @@
package internal package internal
import ( import (
"net/netip"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -13,7 +15,7 @@ type MobileDependency struct {
TunAdapter device.TunAdapter TunAdapter device.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
NetworkChangeListener listener.NetworkChangeListener NetworkChangeListener listener.NetworkChangeListener
HostDNSAddresses []string HostDNSAddresses []netip.AddrPort
DnsReadyListener dns.ReadyListener DnsReadyListener dns.ReadyListener
// iOS only // iOS only

View File

@@ -140,7 +140,7 @@ type RosenpassState struct {
// whether it's enabled, and the last error message encountered during probing. // whether it's enabled, and the last error message encountered during probing.
type NSGroupState struct { type NSGroupState struct {
ID string ID string
Servers []string Servers []netip.AddrPort
Domains []string Domains []string
Enabled bool Enabled bool
Error error Error error

View File

@@ -28,7 +28,6 @@ import (
const ( const (
// managementLegacyPortString is the port that was used before by the Management gRPC server. // managementLegacyPortString is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now. // It is used for backward compatibility now.
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
managementLegacyPortString = "33073" managementLegacyPortString = "33073"
// DefaultManagementURL points to the NetBird's cloud management endpoint // DefaultManagementURL points to the NetBird's cloud management endpoint
DefaultManagementURL = "https://api.netbird.io:443" DefaultManagementURL = "https://api.netbird.io:443"
@@ -594,17 +593,9 @@ func update(input ConfigInput) (*Config, error) {
return config, nil return config, nil
} }
// GetConfig read config file and return with Config. Errors out if it does not exist
func GetConfig(configPath string) (*Config, error) { func GetConfig(configPath string) (*Config, error) {
if !fileExists(configPath) { return readConfig(configPath, false)
return nil, fmt.Errorf("config file %s does not exist", configPath)
}
config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", configPath, err)
}
return config, nil
} }
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
@@ -696,6 +687,11 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) { func ReadConfig(configPath string) (*Config, error) {
return readConfig(configPath, true)
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
if fileExists(configPath) { if fileExists(configPath) {
err := util.EnforcePermission(configPath) err := util.EnforcePermission(configPath)
if err != nil { if err != nil {
@@ -716,6 +712,8 @@ func ReadConfig(configPath string) (*Config, error) {
} }
return config, nil return config, nil
} else if !createIfMissing {
return nil, fmt.Errorf("config file %s does not exist", configPath)
} }
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath}) cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})

View File

@@ -16,19 +16,21 @@
<StandardDirectory Id="ProgramFiles64Folder"> <StandardDirectory Id="ProgramFiles64Folder">
<Directory Id="NetbirdInstallDir" Name="Netbird"> <Directory Id="NetbirdInstallDir" Name="Netbird">
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64"> <Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird.exe" KeyPath="yes" /> <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird-ui.exe"> <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" /> <Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" /> <Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
</File> </File>
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\wintun.dll" /> <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\opengl32.dll" /> <?if $(var.ArchSuffix) = "amd64" ?>
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
<?endif ?>
<ServiceInstall <ServiceInstall
Id="NetBirdService" Id="NetBirdService"
Name="NetBird" Name="NetBird"
DisplayName="NetBird" DisplayName="NetBird"
Description="A WireGuard-based mesh network that connects your devices into a single private network." Description="Connect your devices into a secure WireGuard-based overlay network with SSO, MFA and granular access controls."
Start="auto" Type="ownProcess" Start="auto" Type="ownProcess"
ErrorControl="normal" ErrorControl="normal"
Account="LocalSystem" Account="LocalSystem"

View File

@@ -1197,8 +1197,14 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
if dnsState.Error != nil { if dnsState.Error != nil {
err = dnsState.Error.Error() err = dnsState.Error.Error()
} }
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{ pbDnsState := &proto.NSGroupState{
Servers: dnsState.Servers, Servers: servers,
Domains: dnsState.Domains, Domains: dnsState.Domains,
Enabled: dnsState.Enabled, Enabled: dnsState.Enabled,
Error: err, Error: err,

View File

@@ -46,7 +46,7 @@ func (s *serviceClient) showProfilesUI() {
widget.NewLabel(""), // profile name widget.NewLabel(""), // profile name
layout.NewSpacer(), layout.NewSpacer(),
widget.NewButton("Select", nil), widget.NewButton("Select", nil),
widget.NewButton("Logout", nil), widget.NewButton("Deregister", nil),
widget.NewButton("Remove", nil), widget.NewButton("Remove", nil),
) )
}, },
@@ -128,7 +128,7 @@ func (s *serviceClient) showProfilesUI() {
} }
logoutBtn.Show() logoutBtn.Show()
logoutBtn.SetText("Logout") logoutBtn.SetText("Deregister")
logoutBtn.OnTapped = func() { logoutBtn.OnTapped = func() {
s.handleProfileLogout(profile.Name, refresh) s.handleProfileLogout(profile.Name, refresh)
} }
@@ -334,8 +334,8 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) { func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
dialog.ShowConfirm( dialog.ShowConfirm(
"Logout", "Deregister",
fmt.Sprintf("Are you sure you want to logout from '%s'?", profileName), fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
func(confirm bool) { func(confirm bool) {
if !confirm { if !confirm {
return return
@@ -362,13 +362,13 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
}) })
if err != nil { if err != nil {
log.Errorf("logout failed: %v", err) log.Errorf("logout failed: %v", err)
dialog.ShowError(fmt.Errorf("logout failed"), s.wProfiles) dialog.ShowError(fmt.Errorf("deregister failed"), s.wProfiles)
return return
} }
dialog.ShowInformation( dialog.ShowInformation(
"Logged Out", "Deregistered",
fmt.Sprintf("Successfully logged out from '%s'", profileName), fmt.Sprintf("Successfully deregistered from '%s'", profileName),
s.wProfiles, s.wProfiles,
) )
@@ -602,7 +602,7 @@ func (p *profileMenu) refresh() {
// Add Logout menu item // Add Logout menu item
ctx2, cancel2 := context.WithCancel(context.Background()) ctx2, cancel2 := context.WithCancel(context.Background())
logoutItem := p.profileMenuItem.AddSubMenuItem("Logout", "") logoutItem := p.profileMenuItem.AddSubMenuItem("Deregister", "")
p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2} p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2}
go func() { go func() {
@@ -616,9 +616,9 @@ func (p *profileMenu) refresh() {
} }
if err := p.eventHandler.logout(p.ctx); err != nil { if err := p.eventHandler.logout(p.ctx); err != nil {
log.Errorf("logout failed: %v", err) log.Errorf("logout failed: %v", err)
p.app.SendNotification(fyne.NewNotification("Error", "Failed to logout")) p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
} else { } else {
p.app.SendNotification(fyne.NewNotification("Success", "Logged out successfully")) p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
} }
} }
} }

View File

@@ -102,6 +102,11 @@ func (n *NameServer) IsEqual(other *NameServer) bool {
other.Port == n.Port other.Port == n.Port
} }
// AddrPort returns the nameserver as a netip.AddrPort
func (n *NameServer) AddrPort() netip.AddrPort {
return netip.AddrPortFrom(n.IP, uint16(n.Port))
}
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53 // ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
func ParseNameServerURL(nsURL string) (NameServer, error) { func ParseNameServerURL(nsURL string) (NameServer, error) {
parsedURL, err := url.Parse(nsURL) parsedURL, err := url.Parse(nsURL)

View File

@@ -9,7 +9,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/context"
) )
type ExecutionContext string type ExecutionContext string

View File

@@ -40,12 +40,12 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -346,12 +346,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
} }
if updateAccountPeers || groupsUpdated { if updateAccountPeers || groupsUpdated {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
} }
return transaction.SaveAccountSettings(ctx, store.LockingStrengthUpdate, accountID, newSettings) return transaction.SaveAccountSettings(ctx, accountID, newSettings)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -405,7 +405,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
} }
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -571,19 +571,19 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
accountId := xid.New().String() accountId := xid.New().String()
_, err := am.Store.GetAccount(ctx, accountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
statusErr, _ := status.FromError(err) if err != nil {
switch { return nil, err
case err == nil: }
if exists {
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
continue continue
case statusErr.Type() == status.NotFound: }
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy) newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil) am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil return newAccount, nil
default:
return nil, err
}
} }
return nil, status.Errorf(status.Internal, "error while creating new account") return nil, status.Errorf(status.Internal, "error while creating new account")
@@ -746,7 +746,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
// AccountExists checks if an account exists. // AccountExists checks if an account exists.
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountIDByUserID retrieves the account ID based on the userID provided. // GetAccountIDByUserID retrieves the account ID based on the userID provided.
@@ -758,7 +758,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided") return "", status.Errorf(status.NotFound, "no valid userID provided")
} }
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@@ -813,7 +813,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID) accountIDString := fmt.Sprintf("%v", accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -867,7 +867,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -897,7 +897,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err return nil, err
@@ -1048,7 +1048,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount() defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return err return err
@@ -1058,7 +1058,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return nil return nil
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err) log.WithContext(ctx).Errorf("error getting user: %v", err)
return err return err
@@ -1143,9 +1143,12 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount() defer unlockAccount()
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
newUser := types.NewRegularUser(userAuth.UserId) newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) err = am.Store.SaveUser(ctx, newUser)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -1156,10 +1159,15 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
} }
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
return domainAccountID, nil return domainAccountID, nil
} }
return "", err
}
return user.AccountID, nil
}
// redeemInvite checks whether user has been invited and redeems the invite // redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error { func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
// only possible with the enabled IdP manager // only possible with the enabled IdP manager
@@ -1223,7 +1231,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountOnboarding retrieves the onboarding information for a specific account. // GetAccountOnboarding retrieves the onboarding information for a specific account.
@@ -1308,7 +1316,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return "", "", err return "", "", err
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
// this is not really possible because we got an account by user ID // this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
@@ -1340,7 +1348,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return err return err
} }
@@ -1366,12 +1374,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
var hasChanges bool var hasChanges bool
var user *types.User var user *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user: %w", err) return fmt.Errorf("error getting user: %w", err)
} }
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -1387,7 +1395,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil { if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }
@@ -1395,13 +1403,13 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups) removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)
user.AutoGroups = updatedAutoGroups user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, store.LockingStrengthUpdate, user); err != nil { if err = transaction.SaveUser(ctx, user); err != nil {
return fmt.Errorf("error saving user: %w", err) return fmt.Errorf("error saving user: %w", err)
} }
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user peers: %w", err) return fmt.Errorf("error getting user peers: %w", err)
} }
@@ -1419,7 +1427,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil { if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err) return fmt.Errorf("error incrementing network serial: %w", err)
} }
} }
@@ -1437,7 +1445,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1450,7 +1458,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1511,7 +1519,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
} }
if userAuth.IsChild { if userAuth.IsChild {
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil || !exists { if err != nil || !exists {
return "", err return "", err
} }
@@ -1535,7 +1543,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err return "", err
} }
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1556,7 +1564,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth) return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
} }
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1571,7 +1579,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
// check again if the domain has a primary account because of simultaneous requests // check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
cancel() cancel()
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1582,7 +1590,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
} }
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1592,7 +1600,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId) return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
} }
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err return "", err
@@ -1603,7 +1611,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
} }
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err return "", err
@@ -1751,7 +1759,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
} }
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -1780,7 +1788,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
if !allowed { if !allowed {
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
} }
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
@@ -1870,7 +1878,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel() defer cancel()
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
return nil, false, err return nil, false, err
} }
@@ -1890,7 +1898,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
for range 2 { for range 2 {
accountId := xid.New().String() accountId := xid.New().String()
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId)
if err != nil || exists { if err != nil || exists {
continue continue
} }
@@ -1965,7 +1973,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return nil return nil
} }
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain) existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
// error is not a not found error // error is not a not found error
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
@@ -2002,17 +2010,17 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
// propagateUserGroupMemberships propagates all account users' group memberships to their peers. // propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error. // Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID) accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, fmt.Errorf("error getting account group peers: %w", err) return false, false, fmt.Errorf("error getting account group peers: %w", err)
} }
accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, fmt.Errorf("error getting account groups: %w", err) return false, false, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -2025,7 +2033,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
updatedGroups := []string{} updatedGroups := []string{}
for _, user := range users { for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id) userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
@@ -2074,7 +2082,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
account.Network.Net = newIPNet account.Network.Net = newIPNet
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -2099,7 +2107,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
} }
for _, peer := range peers { for _, peer := range peers {
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err) return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
} }
} }
@@ -2154,7 +2162,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context,
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return fmt.Errorf("get peer: %w", err) return fmt.Errorf("get peer: %w", err)
} }
@@ -2185,7 +2193,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context,
func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error { func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error {
log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP) log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return fmt.Errorf("get account settings: %w", err) return fmt.Errorf("get account settings: %w", err)
} }
@@ -2195,7 +2203,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
oldIP := peer.IP.String() oldIP := peer.IP.String()
peer.IP = newIP.AsSlice() peer.IP = newIP.AsSlice()
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer) err = transaction.SavePeer(ctx, accountID, peer)
if err != nil { if err != nil {
return fmt.Errorf("save peer: %w", err) return fmt.Errorf("save peer: %w", err)
} }

View File

@@ -783,7 +783,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
return return
} }
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid") assert.True(t, exists, "expected to get existing account after creation using userid")
@@ -900,11 +900,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
} }
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1") pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId) pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
} }
@@ -1786,7 +1786,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings) assert.NotNil(t, settings)
@@ -1971,7 +1971,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updatedSettings.PeerLoginExpirationEnabled) assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled) assert.False(t, settings.PeerLoginExpirationEnabled)
@@ -2655,7 +2655,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
}) })
@@ -2669,7 +2669,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty") assert.Empty(t, user.AutoGroups, "auto groups must be empty")
}) })
@@ -2683,18 +2683,18 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"} account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
claims := nbcontext.UserAuth{ claims := nbcontext.UserAuth{
UserId: "user1", UserId: "user1",
@@ -2704,11 +2704,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2722,7 +2722,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2736,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2750,11 +2750,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
@@ -2768,7 +2768,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
@@ -2783,7 +2783,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
}) })
@@ -3348,11 +3348,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1) err = manager.Store.AddPeerToAccount(ctx, peer1)
require.NoError(t, err) require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2) err = manager.Store.AddPeerToAccount(ctx, peer2)
require.NoError(t, err) require.NoError(t, err)
t.Run("should skip propagation when the user has no groups", func(t *testing.T) { t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
@@ -3364,20 +3364,20 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) { t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id} group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1)) require.NoError(t, manager.Store.CreateGroup(ctx, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID) user.AutoGroups = append(user.AutoGroups, group1.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID) group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1") assert.Contains(t, group.Peers, "peer1")
@@ -3386,13 +3386,13 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership and account peers for used groups", func(t *testing.T) { t.Run("should update membership and account peers for used groups", func(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id} group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2)) require.NoError(t, manager.Store.CreateGroup(ctx, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID) user.AutoGroups = append(user.AutoGroups, group2.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
Name: "Group1 Policy", Name: "Group1 Policy",
@@ -3415,7 +3415,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers) assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
@@ -3432,18 +3432,18 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}) })
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) { t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = []string{"group1"} user.AutoGroups = []string{"group1"}
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, groupsUpdated) assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
@@ -3453,6 +3453,50 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}) })
} }
func TestDefaultAccountManager_AddNewUserToDomainAccount(t *testing.T) {
testCases := []struct {
name string
userAuth nbcontext.UserAuth
expectedRole types.UserRole
}{
{
name: "existing user",
userAuth: nbcontext.UserAuth{
Domain: "example.com",
UserId: "user1",
},
expectedRole: types.UserRoleOwner,
},
{
name: "new user",
userAuth: nbcontext.UserAuth{
Domain: "example.com",
UserId: "user2",
},
expectedRole: types.UserRoleUser,
},
}
manager, err := createManager(t)
require.NoError(t, err)
accountID, err := manager.GetAccountIDByUserID(context.Background(), "user1", "example.com")
require.NoError(t, err, "create init user failed")
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
userAccountID, err := manager.addNewUserToDomainAccount(context.Background(), accountID, tc.userAuth)
require.NoError(t, err)
assert.Equal(t, accountID, userAccountID)
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, tc.userAuth.UserId)
require.NoError(t, err)
assert.Equal(t, accountID, user.AccountID)
assert.Equal(t, tc.expectedRole, user.Role)
})
}
}
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
return userAuth, nil return userAuth, nil
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return userAuth, err return userAuth, err
} }
@@ -94,7 +94,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
// MarkPATUsed marks a personal access token as used // MarkPATUsed marks a personal access token as used
func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error { func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error {
return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) return am.store.MarkPATUsed(ctx, tokenID)
} }
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token. // GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
return nil, nil, "", "", err return nil, nil, "", "", err
} }
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
if err != nil { if err != nil {
return nil, nil, "", "", err return nil, nil, "", "", err
} }
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
var pat *types.PersonalAccessToken var pat *types.PersonalAccessToken
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
if err != nil { if err != nil {
return err return err
} }
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
return err return err
}) })
if err != nil { if err != nil {

View File

@@ -1,8 +1,10 @@
package context package context
import "github.com/netbirdio/netbird/shared/context"
const ( const (
RequestIDKey = "requestID" RequestIDKey = context.RequestIDKey
AccountIDKey = "accountID" AccountIDKey = context.AccountIDKey
UserIDKey = "userID" UserIDKey = context.UserIDKey
PeerIDKey = "peerID" PeerIDKey = context.PeerIDKey
) )

View File

@@ -8,14 +8,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
) )
// DNSConfigCache is a thread-safe cache for DNS configuration components // DNSConfigCache is a thread-safe cache for DNS configuration components
@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
} }
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
@@ -113,11 +113,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave) return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
var eventsToStore []func() var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil return nil
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return nil return nil
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -24,7 +24,7 @@ import (
"github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (

View File

@@ -134,7 +134,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
} }
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return return

View File

@@ -11,9 +11,9 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
func isEnabled() bool { func isEnabled() bool {
@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
} }
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) { func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
continue continue
} }
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
if err != nil { if err != nil {
// @todo consider logging // @todo consider logging
continue continue

View File

@@ -13,7 +13,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
) )
type GeoNames struct { type GeoNames struct {

View File

@@ -14,11 +14,11 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
type GroupLinkError struct { type GroupLinkError struct {
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
} }
// CreateGroup object of the peers // CreateGroup object of the peers
@@ -96,11 +96,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil { if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err) return status.Errorf(status.Internal, "failed to create group: %v", err)
} }
@@ -147,7 +147,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err return err
} }
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err != nil { if err != nil {
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
} }
@@ -176,11 +176,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup) return transaction.UpdateGroup(ctx, newGroup)
}) })
if err != nil { if err != nil {
return err return err
@@ -234,11 +234,11 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) return transaction.CreateGroups(ctx, accountID, groupsToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -292,11 +292,11 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) return transaction.UpdateGroups(ctx, accountID, groupsToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -320,7 +320,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err == nil && oldGroup != nil { if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
@@ -332,13 +332,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
} }
modifiedPeers := slices.Concat(addedPeers, removedPeers) modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil return nil
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil return nil
@@ -423,11 +423,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
deletedGroups = append(deletedGroups, group) deletedGroups = append(deletedGroups, group)
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete) return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete)
}) })
if err != nil { if err != nil {
return err return err
@@ -454,7 +454,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -495,11 +495,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) return transaction.UpdateGroup(ctx, group)
}) })
if err != nil { if err != nil {
return err return err
@@ -526,7 +526,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -567,11 +567,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) return transaction.UpdateGroup(ctx, group)
}) })
if err != nil { if err != nil {
return err return err
@@ -591,7 +591,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
if err != nil { if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err return err
@@ -608,7 +608,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
for _, peerID := range newGroup.Peers { for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
@@ -620,7 +620,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration { if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get user") return status.Errorf(status.Internal, "failed to get user")
} }
@@ -666,7 +666,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings") return status.Errorf(status.Internal, "failed to get DNS settings")
} }
@@ -675,7 +675,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get account settings") return status.Errorf(status.Internal, "failed to get account settings")
} }
@@ -689,7 +689,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil return false, nil
@@ -709,7 +709,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil return false, nil
@@ -727,7 +727,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil return false, nil
@@ -746,7 +746,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil return false, nil
@@ -762,7 +762,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil return false, nil
@@ -778,7 +778,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account. // isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) { func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil return false, nil
@@ -798,7 +798,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil return false, nil
} }
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -826,7 +826,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -26,10 +26,10 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer" peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -898,7 +898,7 @@ func Test_AddPeerAndAddToAll(t *testing.T) {
} }
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) err = transaction.AddPeerToAccount(context.Background(), peer)
if err != nil { if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err) return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
} }
@@ -971,7 +971,7 @@ func Test_IncrementNetworkSerial(t *testing.T) {
<-start <-start
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID) err = transaction.IncrementNetworkSerial(context.Background(), accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account %s: %v", accountID, err) return fmt.Errorf("failed to get account %s: %v", accountID, err)
} }

View File

@@ -6,12 +6,12 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
) )
type Manager interface { type Manager interface {
@@ -49,7 +49,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
return nil, err return nil, err
} }
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -96,13 +96,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
return nil, fmt.Errorf("error adding resource to group: %w", err) return nil, fmt.Errorf("error adding resource to group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }
@@ -120,13 +120,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
return nil, fmt.Errorf("error removing resource from group: %w", err) return nil, fmt.Errorf("error removing resource from group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
@@ -32,9 +31,10 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
internalStatus "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
internalStatus "github.com/netbirdio/netbird/shared/management/status"
) )
// GRPCServer an instance of a Management gRPC API server // GRPCServer an instance of a Management gRPC API server
@@ -913,6 +913,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
start := time.Now()
empty := &proto.Empty{} empty := &proto.Empty{}
peerKey, err := s.parseRequest(ctx, req, empty) peerKey, err := s.parseRequest(ctx, req, empty)
@@ -920,7 +921,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
return nil, err return nil, err
} }
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peerKey.String()) peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String()) log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String())
// TODO: consider idempotency // TODO: consider idempotency
@@ -944,7 +945,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID) s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String()) log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
return &proto.Empty{}, nil return &proto.Empty{}, nil
} }

View File

@@ -11,10 +11,10 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -15,10 +15,10 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -9,8 +9,8 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -11,8 +11,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/gorilla/mux" "github.com/gorilla/mux"

View File

@@ -11,9 +11,9 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
) )
// nameserversHandler is the nameserver group handler of the account // nameserversHandler is the nameserver group handler of the account

View File

@@ -13,8 +13,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/gorilla/mux" "github.com/gorilla/mux"

View File

@@ -10,8 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
// handler HTTP handler // handler HTTP handler

View File

@@ -16,7 +16,7 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -11,9 +11,9 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -19,11 +19,11 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -12,14 +12,14 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
nbtypes "github.com/netbirdio/netbird/management/server/types" nbtypes "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -8,8 +8,8 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/networks/resources/types"
) )

View File

@@ -7,8 +7,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/routers/types"
) )

View File

@@ -14,10 +14,10 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -17,7 +17,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"

View File

@@ -16,7 +16,7 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"

View File

@@ -9,12 +9,12 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
) )
var ( var (

View File

@@ -10,9 +10,9 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -14,9 +14,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -9,10 +9,10 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
) )
// postureChecksHandler is a handler that returns posture checks of the account. // postureChecksHandler is a handler that returns posture checks of the account.

View File

@@ -16,10 +16,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
) )
var berlin = "Berlin" var berlin = "Berlin"

View File

@@ -11,9 +11,9 @@ import (
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )

View File

@@ -10,9 +10,9 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -15,9 +15,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -8,9 +8,9 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

View File

@@ -9,9 +9,9 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"

View File

@@ -16,11 +16,11 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
) )

View File

@@ -13,8 +13,8 @@ import (
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )

Some files were not shown because too many files have changed in this diff Show More