Merge branch 'main' into fix/ice-handshake

This commit is contained in:
Zoltán Papp
2025-08-13 10:47:54 +02:00
119 changed files with 2002 additions and 1176 deletions

View File

@@ -12,6 +12,16 @@
- [ ] Is a feature enhancement - [ ] Is a feature enhancement
- [ ] It is a refactor - [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible) - [ ] Created tests that fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md). > By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
## Documentation
Select exactly one:
- [ ] I added/updated documentation for this change
- [ ] Documentation is **not needed** for this change (explain why)
### Docs PR URL (required if "docs added" is checked)
Paste the PR link from https://github.com/netbirdio/docs here:
https://github.com/netbirdio/docs/pull/__

94
.github/workflows/docs-ack.yml vendored Normal file
View File

@@ -0,0 +1,94 @@
name: Docs Acknowledgement
on:
pull_request:
types: [opened, edited, synchronize]
permissions:
contents: read
pull-requests: read
jobs:
docs-ack:
name: Require docs PR URL or explicit "not needed"
runs-on: ubuntu-latest
steps:
- name: Read PR body
id: body
run: |
BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH")
echo "body<<EOF" >> $GITHUB_OUTPUT
echo "$BODY" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
- name: Validate checkbox selection
id: validate
run: |
body='${{ steps.body.outputs.body }}'
added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ')
noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ')
if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then
echo "::error::Choose exactly one: either 'docs added' OR 'not needed'."
exit 1
fi
if [ "$added_checked" -eq 0 ] && [ "$noneed_checked" -eq 0 ]; then
echo "::error::You must check exactly one docs option in the PR template."
exit 1
fi
if [ "$added_checked" -eq 1 ]; then
echo "mode=added" >> $GITHUB_OUTPUT
else
echo "mode=noneed" >> $GITHUB_OUTPUT
fi
- name: Extract docs PR URL (when 'docs added')
if: steps.validate.outputs.mode == 'added'
id: extract
run: |
body='${{ steps.body.outputs.body }}'
# Strictly require HTTPS and that it's a PR in netbirdio/docs
# Examples accepted:
# https://github.com/netbirdio/docs/pull/1234
url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)
if [ -z "$url" ]; then
echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)."
exit 1
fi
pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')
echo "url=$url" >> $GITHUB_OUTPUT
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT
- name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added'
uses: actions/github-script@v7
id: verify
with:
pr_number: ${{ steps.extract.outputs.pr_number }}
script: |
const prNumber = parseInt(core.getInput('pr_number'), 10);
const { data } = await github.rest.pulls.get({
owner: 'netbirdio',
repo: 'docs',
pull_number: prNumber
});
// Allow open or merged PRs
const ok = data.state === 'open' || data.merged === true;
core.setOutput('state', data.state);
core.setOutput('merged', String(!!data.merged));
if (!ok) {
core.setFailed(`Docs PR #${prNumber} exists but is neither open nor merged (state=${data.state}, merged=${data.merged}).`);
}
result-encoding: string
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: All good
run: echo "Documentation requirement satisfied ✅"

18
.github/workflows/forum.yml vendored Normal file
View File

@@ -0,0 +1,18 @@
name: Post release topic on Discourse
on:
release:
types: [published]
jobs:
post:
runs-on: ubuntu-latest
steps:
- uses: roots/discourse-topic-github-release-action@main
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.21" SIGN_PIPE_VER: "v0.0.22"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -79,6 +79,8 @@ jobs:
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
@@ -154,10 +156,20 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install LLVM-MinGW for ARM64 cross-compilation
run: |
cd /tmp
wget -q https://github.com/mstorsjo/llvm-mingw/releases/download/20250709/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
echo "60cafae6474c7411174cff1d4ba21a8e46cadbaeb05a1bace306add301628337 llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz" | sha256sum -c
tar -xf llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
@@ -231,17 +243,3 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -16,8 +16,6 @@ builds:
- arm64 - arm64
- 386 - 386
ignore: ignore:
- goos: windows
goarch: arm64
- goos: windows - goos: windows
goarch: arm goarch: arm
- goos: windows - goos: windows

View File

@@ -15,7 +15,7 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows - id: netbird-ui-windows-amd64
dir: client/ui dir: client/ui
binary: netbird-ui binary: netbird-ui
env: env:
@@ -30,6 +30,22 @@ builds:
- -H windowsgui - -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows-arm64
dir: client/ui
binary: netbird-ui
env:
- CGO_ENABLED=1
- CC=aarch64-w64-mingw32-clang
- CXX=aarch64-w64-mingw32-clang++
goos:
- windows
goarch:
- arm64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}"
archives: archives:
- id: linux-arch - id: linux-arch
name_template: "{{ .ProjectName }}-linux_{{ .Version }}_{{ .Os }}_{{ .Arch }}" name_template: "{{ .ProjectName }}-linux_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
@@ -38,7 +54,8 @@ archives:
- id: windows-arch - id: windows-arch
name_template: "{{ .ProjectName }}-windows_{{ .Version }}_{{ .Os }}_{{ .Arch }}" name_template: "{{ .ProjectName }}-windows_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds: builds:
- netbird-ui-windows - netbird-ui-windows-amd64
- netbird-ui-windows-arm64
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>

View File

@@ -4,6 +4,7 @@ package android
import ( import (
"context" "context"
"slices"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -112,7 +113,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -138,7 +139,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -235,7 +236,7 @@ func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
return err return err
} }
dnsServer.OnUpdatedHostDNSServer(list.items) dnsServer.OnUpdatedHostDNSServer(slices.Clone(list.items))
return nil return nil
} }

View File

@@ -1,23 +1,34 @@
package android package android
import "fmt" import (
"fmt"
"net/netip"
// DNSList is a wrapper of []string "github.com/netbirdio/netbird/client/internal/dns"
)
// DNSList is a wrapper of []netip.AddrPort with default DNS port
type DNSList struct { type DNSList struct {
items []string items []netip.AddrPort
} }
// Add new DNS address to the collection // Add new DNS address to the collection, returns error if invalid
func (array *DNSList) Add(s string) { func (array *DNSList) Add(s string) error {
array.items = append(array.items, s) addr, err := netip.ParseAddr(s)
if err != nil {
return fmt.Errorf("invalid DNS address: %s", s)
}
addrPort := netip.AddrPortFrom(addr.Unmap(), dns.DefaultPort)
array.items = append(array.items, addrPort)
return nil
} }
// Get return an element of the collection // Get return an element of the collection as string
func (array *DNSList) Get(i int) (string, error) { func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 { if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range") return "", fmt.Errorf("out of range")
} }
return array.items[i], nil return array.items[i].Addr().String(), nil
} }
// Size return with the size of the collection // Size return with the size of the collection

View File

@@ -3,20 +3,30 @@ package android
import "testing" import "testing"
func TestDNSList_Get(t *testing.T) { func TestDNSList_Get(t *testing.T) {
l := DNSList{ l := DNSList{}
items: make([]string, 1),
// Add a valid DNS address
err := l.Add("8.8.8.8")
if err != nil {
t.Errorf("unexpected error: %s", err)
} }
_, err := l.Get(0) // Test getting valid index
addr, err := l.Get(0)
if err != nil { if err != nil {
t.Errorf("invalid error: %s", err) t.Errorf("invalid error: %s", err)
} }
if addr != "8.8.8.8" {
t.Errorf("expected 8.8.8.8, got %s", addr)
}
// Test negative index
_, err = l.Get(-1) _, err = l.Get(-1)
if err == nil { if err == nil {
t.Errorf("expected error but got nil") t.Errorf("expected error but got nil")
} }
// Test out of bounds index
_, err = l.Get(1) _, err = l.Get(1)
if err == nil { if err == nil {
t.Errorf("expected error but got nil") t.Errorf("expected error but got nil")

View File

@@ -33,7 +33,7 @@ var (
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
Use: "debug", Use: "debug",
Short: "Debugging commands", Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the Netbird daemon.", Long: "Provides commands for debugging and logging control within the NetBird daemon.",
} }
var debugBundleCmd = &cobra.Command{ var debugBundleCmd = &cobra.Command{
@@ -46,8 +46,8 @@ var debugBundleCmd = &cobra.Command{
var logCmd = &cobra.Command{ var logCmd = &cobra.Command{
Use: "log", Use: "log",
Short: "Manage logging for the Netbird daemon", Short: "Manage logging for the NetBird daemon",
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`, Long: `Commands to manage logging settings for the NetBird daemon, including ICE, gRPC, and general log levels.`,
} }
var logLevelCmd = &cobra.Command{ var logLevelCmd = &cobra.Command{
@@ -184,7 +184,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird up") cmd.Println("netbird up")
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
} }
@@ -202,7 +202,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird down") cmd.Println("netbird down")
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@@ -216,11 +216,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird up") cmd.Println("netbird up")
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag)) statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
@@ -230,7 +230,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...") cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
@@ -250,7 +250,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
} }
cmd.Println("Netbird down") cmd.Println("netbird down")
} }
if !initialLevelTrace { if !initialLevelTrace {

View File

@@ -31,7 +31,7 @@ func init() {
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
Use: "login", Use: "login",
Short: "login to the Netbird Management Service (first run)", Short: "login to the NetBird Management Service (first run)",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setEnvAndFlags(cmd); err != nil { if err := setEnvAndFlags(cmd); err != nil {
return fmt.Errorf("set env and flags: %v", err) return fmt.Errorf("set env and flags: %v", err)

View File

@@ -12,14 +12,15 @@ import (
) )
var logoutCmd = &cobra.Command{ var logoutCmd = &cobra.Command{
Use: "logout", Use: "deregister",
Short: "logout from the Netbird Management Service and delete peer", Aliases: []string{"logout"},
Short: "deregister from the NetBird Management Service and delete peer",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) ctx, cancel := context.WithTimeout(cmd.Context(), time.Second*15)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
@@ -44,10 +45,10 @@ var logoutCmd = &cobra.Command{
} }
if _, err := daemonClient.Logout(ctx, req); err != nil { if _, err := daemonClient.Logout(ctx, req); err != nil {
return fmt.Errorf("logout: %v", err) return fmt.Errorf("deregister: %v", err)
} }
cmd.Println("Logged out successfully") cmd.Println("Deregistered successfully")
return nil return nil
}, },
} }

View File

@@ -16,14 +16,14 @@ import (
var profileCmd = &cobra.Command{ var profileCmd = &cobra.Command{
Use: "profile", Use: "profile",
Short: "manage Netbird profiles", Short: "manage NetBird profiles",
Long: `Manage Netbird profiles, allowing you to list, switch, and remove profiles.`, Long: `Manage NetBird profiles, allowing you to list, switch, and remove profiles.`,
} }
var profileListCmd = &cobra.Command{ var profileListCmd = &cobra.Command{
Use: "list", Use: "list",
Short: "list all profiles", Short: "list all profiles",
Long: `List all available profiles in the Netbird client.`, Long: `List all available profiles in the NetBird client.`,
Aliases: []string{"ls"}, Aliases: []string{"ls"},
RunE: listProfilesFunc, RunE: listProfilesFunc,
} }
@@ -31,7 +31,7 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{ var profileAddCmd = &cobra.Command{
Use: "add <profile_name>", Use: "add <profile_name>",
Short: "add a new profile", Short: "add a new profile",
Long: `Add a new profile to the Netbird client. The profile name must be unique.`, Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: addProfileFunc, RunE: addProfileFunc,
} }
@@ -39,7 +39,7 @@ var profileAddCmd = &cobra.Command{
var profileRemoveCmd = &cobra.Command{ var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>", Use: "remove <profile_name>",
Short: "remove a profile", Short: "remove a profile",
Long: `Remove a profile from the Netbird client. The profile must not be active.`, Long: `Remove a profile from the NetBird client. The profile must not be active.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: removeProfileFunc, RunE: removeProfileFunc,
} }
@@ -47,7 +47,7 @@ var profileRemoveCmd = &cobra.Command{
var profileSelectCmd = &cobra.Command{ var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>", Use: "select <profile_name>",
Short: "select a profile", Short: "select a profile",
Long: `Select a profile to be the active profile in the Netbird client. The profile must exist.`, Long: `Select a profile to be the active profile in the NetBird client. The profile must exist.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: selectProfileFunc, RunE: selectProfileFunc,
} }

View File

@@ -119,12 +119,12 @@ func init() {
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL)) rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL))
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL)) rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets NetBird log level")
rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets Netbird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character") rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets NetBird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.") rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file") rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")

View File

@@ -50,10 +50,10 @@ func TestSetFlagsFromEnvVars(t *testing.T) {
} }
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`comma separated list of external IPs to map to the Wireguard interface`) `comma separated list of external IPs to map to the WireGuard interface`)
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.") cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec") t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name") t.Setenv("NB_INTERFACE_NAME", "test-name")

View File

@@ -19,7 +19,7 @@ import (
var serviceCmd = &cobra.Command{ var serviceCmd = &cobra.Command{
Use: "service", Use: "service",
Short: "manages Netbird service", Short: "manages NetBird service",
} }
var ( var (
@@ -47,7 +47,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` + serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` + `You can specify a comma-separated list of KEY=VALUE pairs. ` +
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value` `E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
@@ -64,7 +64,7 @@ func newSVCConfig() (*service.Config, error) {
config := &service.Config{ config := &service.Config{
Name: serviceName, Name: serviceName,
DisplayName: "Netbird", DisplayName: "Netbird",
Description: "Netbird mesh network client", Description: "NetBird mesh network client",
Option: make(service.KeyValue), Option: make(service.KeyValue),
EnvVars: make(map[string]string), EnvVars: make(map[string]string),
} }

View File

@@ -24,7 +24,7 @@ import (
func (p *program) Start(svc service.Service) error { func (p *program) Start(svc service.Service) error {
// Start should not block. Do the actual work async. // Start should not block. Do the actual work async.
log.Info("starting Netbird service") //nolint log.Info("starting NetBird service") //nolint
// Collect static system and platform information // Collect static system and platform information
system.UpdateStaticInfo() system.UpdateStaticInfo()
@@ -97,7 +97,7 @@ func (p *program) Stop(srv service.Service) error {
} }
time.Sleep(time.Second * 2) time.Sleep(time.Second * 2)
log.Info("stopped Netbird service") //nolint log.Info("stopped NetBird service") //nolint
return nil return nil
} }
@@ -131,7 +131,7 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel
var runCmd = &cobra.Command{ var runCmd = &cobra.Command{
Use: "run", Use: "run",
Short: "runs Netbird as service", Short: "runs NetBird as service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
@@ -149,7 +149,7 @@ var runCmd = &cobra.Command{
var startCmd = &cobra.Command{ var startCmd = &cobra.Command{
Use: "start", Use: "start",
Short: "starts Netbird service", Short: "starts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel) s, err := setupServiceControlCommand(cmd, ctx, cancel)
@@ -160,14 +160,14 @@ var startCmd = &cobra.Command{
if err := s.Start(); err != nil { if err := s.Start(); err != nil {
return fmt.Errorf("start service: %w", err) return fmt.Errorf("start service: %w", err)
} }
cmd.Println("Netbird service has been started") cmd.Println("NetBird service has been started")
return nil return nil
}, },
} }
var stopCmd = &cobra.Command{ var stopCmd = &cobra.Command{
Use: "stop", Use: "stop",
Short: "stops Netbird service", Short: "stops NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel) s, err := setupServiceControlCommand(cmd, ctx, cancel)
@@ -178,14 +178,14 @@ var stopCmd = &cobra.Command{
if err := s.Stop(); err != nil { if err := s.Stop(); err != nil {
return fmt.Errorf("stop service: %w", err) return fmt.Errorf("stop service: %w", err)
} }
cmd.Println("Netbird service has been stopped") cmd.Println("NetBird service has been stopped")
return nil return nil
}, },
} }
var restartCmd = &cobra.Command{ var restartCmd = &cobra.Command{
Use: "restart", Use: "restart",
Short: "restarts Netbird service", Short: "restarts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel) s, err := setupServiceControlCommand(cmd, ctx, cancel)
@@ -196,14 +196,14 @@ var restartCmd = &cobra.Command{
if err := s.Restart(); err != nil { if err := s.Restart(); err != nil {
return fmt.Errorf("restart service: %w", err) return fmt.Errorf("restart service: %w", err)
} }
cmd.Println("Netbird service has been restarted") cmd.Println("NetBird service has been restarted")
return nil return nil
}, },
} }
var svcStatusCmd = &cobra.Command{ var svcStatusCmd = &cobra.Command{
Use: "status", Use: "status",
Short: "shows Netbird service status", Short: "shows NetBird service status",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel) s, err := setupServiceControlCommand(cmd, ctx, cancel)
@@ -228,7 +228,7 @@ var svcStatusCmd = &cobra.Command{
statusText = fmt.Sprintf("Unknown (%d)", status) statusText = fmt.Sprintf("Unknown (%d)", status)
} }
cmd.Printf("Netbird service status: %s\n", statusText) cmd.Printf("NetBird service status: %s\n", statusText)
return nil return nil
}, },
} }

View File

@@ -99,7 +99,7 @@ func createServiceConfigForInstall() (*service.Config, error) {
var installCmd = &cobra.Command{ var installCmd = &cobra.Command{
Use: "install", Use: "install",
Short: "installs Netbird service", Short: "installs NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil { if err := setupServiceCommand(cmd); err != nil {
return err return err
@@ -122,14 +122,14 @@ var installCmd = &cobra.Command{
return fmt.Errorf("install service: %w", err) return fmt.Errorf("install service: %w", err)
} }
cmd.Println("Netbird service has been installed") cmd.Println("NetBird service has been installed")
return nil return nil
}, },
} }
var uninstallCmd = &cobra.Command{ var uninstallCmd = &cobra.Command{
Use: "uninstall", Use: "uninstall",
Short: "uninstalls Netbird service from system", Short: "uninstalls NetBird service from system",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil { if err := setupServiceCommand(cmd); err != nil {
return err return err
@@ -152,15 +152,15 @@ var uninstallCmd = &cobra.Command{
return fmt.Errorf("uninstall service: %w", err) return fmt.Errorf("uninstall service: %w", err)
} }
cmd.Println("Netbird service has been uninstalled") cmd.Println("NetBird service has been uninstalled")
return nil return nil
}, },
} }
var reconfigureCmd = &cobra.Command{ var reconfigureCmd = &cobra.Command{
Use: "reconfigure", Use: "reconfigure",
Short: "reconfigures Netbird service with new settings", Short: "reconfigures NetBird service with new settings",
Long: `Reconfigures the Netbird service with new settings without manual uninstall/install. Long: `Reconfigures the NetBird service with new settings without manual uninstall/install.
This command will temporarily stop the service, update its configuration, and restart it if it was running.`, This command will temporarily stop the service, update its configuration, and restart it if it was running.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil { if err := setupServiceCommand(cmd); err != nil {
@@ -186,7 +186,7 @@ This command will temporarily stop the service, update its configuration, and re
} }
if wasRunning { if wasRunning {
cmd.Println("Stopping Netbird service...") cmd.Println("Stopping NetBird service...")
if err := s.Stop(); err != nil { if err := s.Stop(); err != nil {
cmd.Printf("Warning: failed to stop service: %v\n", err) cmd.Printf("Warning: failed to stop service: %v\n", err)
} }
@@ -203,13 +203,13 @@ This command will temporarily stop the service, update its configuration, and re
} }
if wasRunning { if wasRunning {
cmd.Println("Starting Netbird service...") cmd.Println("Starting NetBird service...")
if err := s.Start(); err != nil { if err := s.Start(); err != nil {
return fmt.Errorf("start service after reconfigure: %w", err) return fmt.Errorf("start service after reconfigure: %w", err)
} }
cmd.Println("Netbird service has been reconfigured and started") cmd.Println("NetBird service has been reconfigured and started")
} else { } else {
cmd.Println("Netbird service has been reconfigured") cmd.Println("NetBird service has been reconfigured")
} }
return nil return nil

View File

@@ -59,8 +59,8 @@ var sshCmd = &cobra.Command{
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
pm := profilemanager.NewProfileManager() sm := profilemanager.NewServiceManager(configPath)
activeProf, err := pm.GetActiveProfile() activeProf, err := sm.GetActiveProfileState()
if err != nil { if err != nil {
return fmt.Errorf("get active profile: %v", err) return fmt.Errorf("get active profile: %v", err)
} }

View File

@@ -17,7 +17,7 @@ var (
var stateCmd = &cobra.Command{ var stateCmd = &cobra.Command{
Use: "state", Use: "state",
Short: "Manage daemon state", Short: "Manage daemon state",
Long: "Provides commands for managing and inspecting the Netbird daemon state.", Long: "Provides commands for managing and inspecting the NetBird daemon state.",
} }
var stateListCmd = &cobra.Command{ var stateListCmd = &cobra.Command{

View File

@@ -11,6 +11,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -97,6 +98,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT(). settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
@@ -108,7 +110,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -53,15 +53,15 @@ var (
upCmd = &cobra.Command{ upCmd = &cobra.Command{
Use: "up", Use: "up",
Short: "install, login and start Netbird client", Short: "install, login and start NetBird client",
RunE: upFunc, RunE: upFunc,
} }
) )
func init() { func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+ `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
@@ -79,7 +79,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location. ") upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
} }

View File

@@ -9,7 +9,7 @@ import (
var ( var (
versionCmd = &cobra.Command{ versionCmd = &cobra.Command{
Use: "version", Use: "version",
Short: "prints Netbird version", Short: "prints NetBird version",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
cmd.Println(version.NetbirdVersion()) cmd.Println(version.NetbirdVersion())

View File

@@ -3,7 +3,7 @@
!define WEB_SITE "Netbird.io" !define WEB_SITE "Netbird.io"
!define VERSION $%APPVER% !define VERSION $%APPVER%
!define COPYRIGHT "Netbird Authors, 2022" !define COPYRIGHT "Netbird Authors, 2022"
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network" !define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
!define INSTALLER_NAME "netbird-installer.exe" !define INSTALLER_NAME "netbird-installer.exe"
!define MAIN_APP_EXE "Netbird" !define MAIN_APP_EXE "Netbird"
!define ICON "ui\\assets\\netbird.ico" !define ICON "ui\\assets\\netbird.ico"
@@ -59,9 +59,15 @@ ShowInstDetails Show
!define MUI_UNICON "${ICON}" !define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
!define MUI_UNWELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_UNWELCOMEFINISHPAGE_BITMAP "${BANNER}"
!define MUI_FINISHPAGE_RUN !ifndef ARCH
!define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}" !define ARCH "amd64"
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink" !endif
!if ${ARCH} == "amd64"
!define MUI_FINISHPAGE_RUN
!define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}"
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
!endif
###################################################################### ######################################################################
!define MUI_ABORTWARNING !define MUI_ABORTWARNING
@@ -213,7 +219,15 @@ Section -MainProgram
${INSTALL_TYPE} ${INSTALL_TYPE}
# SetOverwrite ifnewer # SetOverwrite ifnewer
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\" !ifndef ARCH
!define ARCH "amd64"
!endif
!if ${ARCH} == "arm64"
File /r "..\\dist\\netbird_windows_arm64\\"
!else
File /r "..\\dist\\netbird_windows_amd64\\"
!endif
SectionEnd SectionEnd
###################################################################### ######################################################################
@@ -292,7 +306,9 @@ DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
!if ${ARCH} == "amd64"
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
!endif
DetailPrint "Removing application directory..." DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
@@ -314,8 +330,10 @@ DetailPrint "Uninstallation finished."
SectionEnd SectionEnd
!if ${ARCH} == "amd64"
Function LaunchLink Function LaunchLink
SetShellVarContext all SetShellVarContext all
SetOutPath $INSTDIR SetOutPath $INSTDIR
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk" ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
FunctionEnd FunctionEnd
!endif

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
@@ -70,7 +71,7 @@ func (c *ConnectClient) RunOnAndroid(
tunAdapter device.TunAdapter, tunAdapter device.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover, iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsAddresses []string, dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener, dnsReadyListener dns.ReadyListener,
) error { ) error {
// in case of non Android os these variables will be nil // in case of non Android os these variables will be nil

View File

@@ -16,7 +16,7 @@ const (
) )
type resolvConf struct { type resolvConf struct {
nameServers []string nameServers []netip.Addr
searchDomains []string searchDomains []string
others []string others []string
} }
@@ -36,7 +36,7 @@ func parseBackupResolvConf() (*resolvConf, error) {
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
rconf := &resolvConf{ rconf := &resolvConf{
searchDomains: make([]string, 0), searchDomains: make([]string, 0),
nameServers: make([]string, 0), nameServers: make([]netip.Addr, 0),
others: make([]string, 0), others: make([]string, 0),
} }
@@ -94,7 +94,11 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
if len(splitLines) != 2 { if len(splitLines) != 2 {
continue continue
} }
rconf.nameServers = append(rconf.nameServers, splitLines[1]) if addr, err := netip.ParseAddr(splitLines[1]); err == nil {
rconf.nameServers = append(rconf.nameServers, addr.Unmap())
} else {
log.Warnf("invalid nameserver address in resolv.conf: %s, skipping", splitLines[1])
}
continue continue
} }
@@ -104,31 +108,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
} }
return rconf, nil return rconf, nil
} }
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location
func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
resolvConf, err := parseResolvConfFile(filename)
if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err)
}
content, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("read %s: %w", filename, err)
}
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
stat, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("stat %s: %w", filename, err)
}
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
return fmt.Errorf("write %s: %w", filename, err)
}
}
return nil
}

View File

@@ -3,13 +3,9 @@
package dns package dns
import ( import (
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_parseResolvConf(t *testing.T) { func Test_parseResolvConf(t *testing.T) {
@@ -99,9 +95,13 @@ options debug
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains) t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
} }
ok = compareLists(cfg.nameServers, testCase.expectedNS) nsStrings := make([]string, len(cfg.nameServers))
for i, ns := range cfg.nameServers {
nsStrings[i] = ns.String()
}
ok = compareLists(nsStrings, testCase.expectedNS)
if !ok { if !ok {
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers) t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, nsStrings)
} }
ok = compareLists(cfg.others, testCase.expectedOther) ok = compareLists(cfg.others, testCase.expectedOther)
@@ -176,87 +176,3 @@ nameserver 192.168.0.1
t.Errorf("unexpected resolv.conf content: %v", cfg) t.Errorf("unexpected resolv.conf content: %v", cfg)
} }
} }
func TestRemoveFirstNbNameserver(t *testing.T) {
testCases := []struct {
name string
content string
ipToRemove string
expected string
}{
{
name: "Unrelated nameservers with comments and options",
content: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
ipToRemove: "9.9.9.9",
expected: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
},
{
name: "First nameserver matches",
content: `search example.com
nameserver 9.9.9.9
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
ipToRemove: "9.9.9.9",
expected: `search example.com
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
},
{
name: "Target IP not the first nameserver",
// nolint:dupword
content: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
ipToRemove: "9.9.9.9",
// nolint:dupword
expected: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
},
{
name: "Only nameserver matches",
content: `options debug
nameserver 9.9.9.9
search localdomain`,
ipToRemove: "9.9.9.9",
expected: `options debug
nameserver 9.9.9.9
search localdomain`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "resolv.conf")
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
assert.NoError(t, err)
ip, err := netip.ParseAddr(tc.ipToRemove)
require.NoError(t, err, "Failed to parse IP address")
err = removeFirstNbNameserver(tempFile, ip)
assert.NoError(t, err)
content, err := os.ReadFile(tempFile)
assert.NoError(t, err)
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
})
}
}

View File

@@ -146,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rCon
return true return true
} }
if rConf.nameServers[0] != nbNameserverIP.String() { if rConf.nameServers[0] != nbNameserverIP {
return true return true
} }

View File

@@ -29,7 +29,7 @@ type fileConfigurator struct {
repair *repair repair *repair
originalPerms os.FileMode originalPerms os.FileMode
nbNameserverIP netip.Addr nbNameserverIP netip.Addr
originalNameservers []string originalNameservers []netip.Addr
} }
func newFileConfigurator() (*fileConfigurator, error) { func newFileConfigurator() (*fileConfigurator, error) {
@@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
} }
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf // getOriginalNameservers returns the nameservers that were found in the original resolv.conf
func (f *fileConfigurator) getOriginalNameservers() []string { func (f *fileConfigurator) getOriginalNameservers() []netip.Addr {
return f.originalNameservers return f.originalNameservers
} }
@@ -128,20 +128,14 @@ func (f *fileConfigurator) backup() error {
} }
func (f *fileConfigurator) restore() error { func (f *fileConfigurator) restore() error {
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP) if err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); err != nil {
if err != nil {
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
}
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
} }
return os.RemoveAll(fileDefaultResolvConfBackupLocation) return os.RemoveAll(fileDefaultResolvConfBackupLocation)
} }
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error {
resolvConf, err := parseDefaultResolvConf() resolvConf, err := parseDefaultResolvConf()
if err != nil { if err != nil {
return fmt.Errorf("parse current resolv.conf: %w", err) return fmt.Errorf("parse current resolv.conf: %w", err)
@@ -152,16 +146,9 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile() return restoreResolvConfFile()
} }
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
// not a valid first nameserver -> restore
if err != nil {
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
return restoreResolvConfFile()
}
// current address is still netbird's non-available dns address -> restore // current address is still netbird's non-available dns address -> restore
// comparing parsed addresses only, to remove ambiguity currentDNSAddress := resolvConf.nameServers[0]
if currentDNSAddress.String() == storedDNSAddress.String() { if currentDNSAddress == storedDNSAddress {
return restoreResolvConfFile() return restoreResolvConfFile()
} }

View File

@@ -239,7 +239,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} else if inServerAddressesArray { } else if inServerAddressesArray {
address := strings.Split(line, " : ")[1] address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address inServerAddressesArray = false // Stop reading after finding the first IPv4 address
} }
} }
@@ -250,7 +250,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} }
// default to 53 port // default to 53 port
dnsSettings.ServerPort = defaultPort dnsSettings.ServerPort = DefaultPort
return dnsSettings, nil return dnsSettings, nil
} }

View File

@@ -42,7 +42,7 @@ func (t osManagerType) String() string {
type restoreHostManager interface { type restoreHostManager interface {
hostManager hostManager
restoreUncleanShutdownDNS(*netip.Addr) error restoreUncleanShutdownDNS(netip.Addr) error
} }
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
@@ -130,8 +130,9 @@ func checkStub() bool {
return true return true
} }
systemdResolvedAddr := netip.AddrFrom4([4]byte{127, 0, 0, 53}) // 127.0.0.53
for _, ns := range rConf.nameServers { for _, ns := range rConf.nameServers {
if ns == "127.0.0.53" { if ns == systemdResolvedAddr {
return true return true
} }
} }

View File

@@ -64,9 +64,10 @@ const (
) )
type registryConfigurator struct { type registryConfigurator struct {
guid string guid string
routingAll bool routingAll bool
gpo bool gpo bool
nrptEntryCount int
} }
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -177,7 +178,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil { if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err) log.Errorf("failed to update shutdown state: %s", err)
} }
@@ -193,13 +198,24 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil {
return fmt.Errorf("add dns match policy: %w", err) return fmt.Errorf("add dns match policy: %w", err)
} }
r.nrptEntryCount = count
} else { } else {
if err := r.removeDNSMatchPolicies(); err != nil { if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err) return fmt.Errorf("remove dns match policies: %w", err)
} }
r.nrptEntryCount = 0
}
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
if err := r.updateSearchDomains(searchDomains); err != nil { if err := r.updateSearchDomains(searchDomains); err != nil {
@@ -216,32 +232,38 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
return fmt.Errorf("adding dns setup for all failed: %w", err) return fmt.Errorf("adding dns setup for all failed: %w", err)
} }
r.routingAll = true r.routingAll = true
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip) log.Infof("configured %s:%d as main DNS forwarder for this peer", ip, DefaultPort)
return nil return nil
} }
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
if r.gpo { for i, domain := range domains {
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil { policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
return fmt.Errorf("configure GPO DNS policy: %w", err) if r.gpo {
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
} }
singleDomain := []string{domain}
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
}
log.Debugf("added NRPT entry for domain: %s", domain)
}
if r.gpo {
if err := refreshGroupPolicy(); err != nil { if err := refreshGroupPolicy(); err != nil {
log.Warnf("failed to refresh group policy: %v", err) log.Warnf("failed to refresh group policy: %v", err)
} }
} else {
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
return fmt.Errorf("configure local DNS policy: %w", err)
}
} }
log.Infof("added %d match domains. Domain list: %s", len(domains), domains) log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
return nil return len(domains), nil
} }
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
return fmt.Errorf("remove existing dns policy: %w", err) return fmt.Errorf("remove existing dns policy: %w", err)
@@ -374,12 +396,25 @@ func (r *registryConfigurator) restoreHostDNS() error {
func (r *registryConfigurator) removeDNSMatchPolicies() error { func (r *registryConfigurator) removeDNSMatchPolicies() error {
var merr *multierror.Error var merr *multierror.Error
// Try to remove the base entries (for backward compatibility)
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
} }
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { for i := 0; i < r.nrptEntryCount; i++ {
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err)) localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
}
} }
if err := refreshGroupPolicy(); err != nil { if err := refreshGroupPolicy(); err != nil {

View File

@@ -1,38 +1,31 @@
package dns package dns
import ( import (
"fmt"
"net/netip" "net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus"
) )
type hostsDNSHolder struct { type hostsDNSHolder struct {
unprotectedDNSList map[string]struct{} unprotectedDNSList map[netip.AddrPort]struct{}
mutex sync.RWMutex mutex sync.RWMutex
} }
func newHostsDNSHolder() *hostsDNSHolder { func newHostsDNSHolder() *hostsDNSHolder {
return &hostsDNSHolder{ return &hostsDNSHolder{
unprotectedDNSList: make(map[string]struct{}), unprotectedDNSList: make(map[netip.AddrPort]struct{}),
} }
} }
func (h *hostsDNSHolder) set(list []string) { func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Lock() h.mutex.Lock()
h.unprotectedDNSList = make(map[string]struct{}) h.unprotectedDNSList = make(map[netip.AddrPort]struct{})
for _, dns := range list { for _, addrPort := range list {
dnsAddr, err := h.normalizeAddress(dns) h.unprotectedDNSList[addrPort] = struct{}{}
if err != nil {
continue
}
h.unprotectedDNSList[dnsAddr] = struct{}{}
} }
h.mutex.Unlock() h.mutex.Unlock()
} }
func (h *hostsDNSHolder) get() map[string]struct{} { func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock() h.mutex.RLock()
l := h.unprotectedDNSList l := h.unprotectedDNSList
h.mutex.RUnlock() h.mutex.RUnlock()
@@ -40,24 +33,10 @@ func (h *hostsDNSHolder) get() map[string]struct{} {
} }
//nolint:unused //nolint:unused
func (h *hostsDNSHolder) isContain(upstream string) bool { func (h *hostsDNSHolder) contains(upstream netip.AddrPort) bool {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
_, ok := h.unprotectedDNSList[upstream] _, ok := h.unprotectedDNSList[upstream]
return ok return ok
} }
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
a, err := netip.ParseAddr(addr)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
return "", err
}
if a.Is4() {
return fmt.Sprintf("%s:53", addr), nil
} else {
return fmt.Sprintf("[%s]:53", addr), nil
}
}

View File

@@ -50,7 +50,7 @@ func (m *MockServer) DnsIP() netip.Addr {
return netip.MustParseAddr("100.10.254.255") return netip.MustParseAddr("100.10.254.255")
} }
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) {
// TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }

View File

@@ -245,7 +245,7 @@ func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
return nil return nil
} }
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := n.restoreHostDNS(); err != nil { if err := n.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via network-manager: %w", err) return fmt.Errorf("restoring dns via network-manager: %w", err)
} }

View File

@@ -40,7 +40,7 @@ type resolvconf struct {
implType resolvconfType implType resolvconfType
originalSearchDomains []string originalSearchDomains []string
originalNameServers []string originalNameServers []netip.Addr
othersConfigs []string othersConfigs []string
} }
@@ -110,7 +110,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
return nil return nil
} }
func (r *resolvconf) getOriginalNameservers() []string { func (r *resolvconf) getOriginalNameservers() []netip.Addr {
return r.originalNameServers return r.originalNameServers
} }
@@ -158,7 +158,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
return nil return nil
} }
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error { func (r *resolvconf) restoreUncleanShutdownDNS(netip.Addr) error {
if err := r.restoreHostDNS(); err != nil { if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err) return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
} }

View File

@@ -42,7 +42,7 @@ type Server interface {
Stop() Stop()
DnsIP() netip.Addr DnsIP() netip.Addr
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string) OnUpdatedHostDNSServer(addrs []netip.AddrPort)
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
} }
@@ -55,7 +55,7 @@ type nsGroupsByDomain struct {
// hostManagerWithOriginalNS extends the basic hostManager interface // hostManagerWithOriginalNS extends the basic hostManager interface
type hostManagerWithOriginalNS interface { type hostManagerWithOriginalNS interface {
hostManager hostManager
getOriginalNameservers() []string getOriginalNameservers() []netip.Addr
} }
// DefaultServer dns server object // DefaultServer dns server object
@@ -136,7 +136,7 @@ func NewDefaultServer(
func NewDefaultServerPermanentUpstream( func NewDefaultServerPermanentUpstream(
ctx context.Context, ctx context.Context,
wgInterface WGIface, wgInterface WGIface,
hostsDnsList []string, hostsDnsList []netip.AddrPort,
config nbdns.Config, config nbdns.Config,
listener listener.NetworkChangeListener, listener listener.NetworkChangeListener,
statusRecorder *peer.Status, statusRecorder *peer.Status,
@@ -144,6 +144,7 @@ func NewDefaultServerPermanentUpstream(
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.hostsDNSHolder.set(hostsDnsList) ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.addHostRootZone() ds.addHostRootZone()
@@ -340,7 +341,7 @@ func (s *DefaultServer) disableDNS() error {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) {
s.hostsDNSHolder.set(hostsDnsList) s.hostsDNSHolder.set(hostsDnsList)
// Check if there's any root handler // Check if there's any root handler
@@ -461,7 +462,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != DefaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
@@ -581,14 +582,13 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
} }
for _, ns := range originalNameservers { for _, ns := range originalNameservers {
if ns == config.ServerIP.String() { if ns == config.ServerIP {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
continue continue
} }
ns = formatAddr(ns, defaultPort) addrPort := netip.AddrPortFrom(ns, DefaultPort)
handler.upstreamServers = append(handler.upstreamServers, addrPort)
handler.upstreamServers = append(handler.upstreamServers, ns)
} }
handler.deactivate = func(error) { /* always active */ } handler.deactivate = func(error) { /* always active */ }
handler.reactivate = func() { /* always active */ } handler.reactivate = func() { /* always active */ }
@@ -695,7 +695,13 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue continue
} }
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
if ns.IP == s.service.RuntimeIP() {
log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP)
continue
}
handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
} }
if len(handler.upstreamServers) == 0 { if len(handler.upstreamServers) == 0 {
@@ -770,18 +776,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
s.dnsMuxMap = muxUpdateMap s.dnsMuxMap = muxUpdateMap
} }
func getNSHostPort(ns nbdns.NameServer) string {
return formatAddr(ns.IP.String(), ns.Port)
}
// formatAddr formats a nameserver address with port, handling IPv6 addresses properly
func formatAddr(address string, port int) string {
if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() {
return fmt.Sprintf("[%s]:%d", address, port)
}
return fmt.Sprintf("%s:%d", address, port)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate // upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to // the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate. // reactivate it. Not allowed to call reactivate before deactivate.
@@ -879,10 +873,7 @@ func (s *DefaultServer) addHostRootZone() {
return return
} }
handler.upstreamServers = make([]string, 0) handler.upstreamServers = maps.Keys(hostDNSServers)
for k := range hostDNSServers {
handler.upstreamServers = append(handler.upstreamServers, k)
}
handler.deactivate = func(error) {} handler.deactivate = func(error) {}
handler.reactivate = func() {} handler.reactivate = func() {}
@@ -893,9 +884,9 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
var states []peer.NSGroupState var states []peer.NSGroupState
for _, group := range groups { for _, group := range groups {
var servers []string var servers []netip.AddrPort
for _, ns := range group.NameServers { for _, ns := range group.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, ns.AddrPort())
} }
state := peer.NSGroupState{ state := peer.NSGroupState{
@@ -927,7 +918,7 @@ func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error,
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
var servers []string var servers []string
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, ns.AddrPort().String())
} }
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
} }

View File

@@ -97,9 +97,9 @@ func init() {
} }
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string var srvs []netip.AddrPort
for _, srv := range servers { for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv)) srvs = append(srvs, srv.AddrPort())
} }
return &upstreamResolverBase{ return &upstreamResolverBase{
domain: domain, domain: domain,
@@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
var dnsList []string var dnsList []netip.AddrPort
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false) dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
@@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
} }
defer dnsServer.Stop() defer dnsServer.Stop()
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"}) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort})
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort()) resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io") _, err = resolver.LookupHost(context.Background(), "netbird.io")
@@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
if err != nil { if err != nil {
t.Errorf("failed to initialize DNS server: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
@@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
if err != nil { if err != nil {
t.Errorf("failed to initialize DNS server: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
@@ -2054,55 +2057,123 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
} }
func TestFormatAddr(t *testing.T) { func TestDNSLoopPrevention(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
dnsServerIP := service.RuntimeIP()
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgInterface,
service: service,
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
tests := []struct { tests := []struct {
name string name string
address string nsGroups []*nbdns.NameServerGroup
port int expectedHandlers int
expected string expectedServers []netip.Addr
shouldFilterOwnIP bool
}{ }{
{ {
name: "IPv4 address", name: "FilterOwnDNSServerIP",
address: "8.8.8.8", nsGroups: []*nbdns.NameServerGroup{
port: 53, {
expected: "8.8.8.8:53", Primary: true,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: true,
}, },
{ {
name: "IPv4 address with custom port", name: "AllServersFiltered",
address: "1.1.1.1", nsGroups: []*nbdns.NameServerGroup{
port: 5353, {
expected: "1.1.1.1:5353", Primary: false,
NameServers: []nbdns.NameServer{
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
},
expectedHandlers: 0,
expectedServers: []netip.Addr{},
shouldFilterOwnIP: true,
}, },
{ {
name: "IPv6 address", name: "MixedServersWithOwnIP",
address: "fd78:94bf:7df8::1", nsGroups: []*nbdns.NameServerGroup{
port: 53, {
expected: "[fd78:94bf:7df8::1]:53", Primary: false,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, // duplicate
},
Domains: []string{"test.com"},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: true,
}, },
{ {
name: "IPv6 address with custom port", name: "NoOwnIPInList",
address: "2001:db8::1", nsGroups: []*nbdns.NameServerGroup{
port: 5353, {
expected: "[2001:db8::1]:5353", Primary: true,
}, NameServers: []nbdns.NameServer{
{ {IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
name: "IPv6 localhost", {IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
address: "::1", },
port: 53, Domains: []string{},
expected: "[::1]:53", },
}, },
{ expectedHandlers: 1,
name: "Invalid address treated as hostname", expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
address: "dns.example.com", shouldFilterOwnIP: false,
port: 53,
expected: "dns.example.com:53",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := formatAddr(tt.address, tt.port) muxUpdates, err := server.buildUpstreamHandlerUpdate(tt.nsGroups)
assert.Equal(t, tt.expected, result) assert.NoError(t, err)
assert.Len(t, muxUpdates, tt.expectedHandlers)
if tt.expectedHandlers > 0 {
handler := muxUpdates[0].handler.(*upstreamResolver)
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
if tt.shouldFilterOwnIP {
for _, upstream := range handler.upstreamServers {
assert.NotEqual(t, dnsServerIP, upstream.Addr())
}
}
for _, expected := range tt.expectedServers {
found := false
for _, upstream := range handler.upstreamServers {
if upstream.Addr() == expected {
found = true
break
}
}
assert.True(t, found, "Expected server %s not found", expected)
}
}
}) })
} }
} }

View File

@@ -7,7 +7,7 @@ import (
) )
const ( const (
defaultPort = 53 DefaultPort = 53
) )
type service interface { type service interface {

View File

@@ -122,7 +122,7 @@ func (s *serviceViaListener) RuntimePort() int {
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil { if s.ebpfService != nil {
return defaultPort return DefaultPort
} else { } else {
return int(s.listenPort) return int(s.listenPort)
} }
@@ -148,9 +148,9 @@ func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
return s.customAddr.Addr(), s.customAddr.Port(), nil return s.customAddr.Addr(), s.customAddr.Port(), nil
} }
ip, ok := s.testFreePort(defaultPort) ip, ok := s.testFreePort(DefaultPort)
if ok { if ok {
return ip, defaultPort, nil return ip, DefaultPort, nil
} }
ebpfSrv, port, ok := s.tryToUseeBPF() ebpfSrv, port, ok := s.tryToUseeBPF()

View File

@@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: lastIP, runtimeIP: lastIP,
runtimePort: defaultPort, runtimePort: DefaultPort,
} }
return s return s
} }

View File

@@ -235,7 +235,7 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
return nil return nil
} }
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via systemd: %w", err) return fmt.Errorf("restoring dns via systemd: %w", err)
} }

View File

@@ -27,7 +27,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create previous host manager: %w", err) return fmt.Errorf("create previous host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil { if err := manager.restoreUncleanShutdownDNS(s.DNSAddress); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }

View File

@@ -5,8 +5,9 @@ import (
) )
type ShutdownState struct { type ShutdownState struct {
Guid string Guid string
GPO bool GPO bool
NRPTEntryCount int
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -15,8 +16,9 @@ func (s *ShutdownState) Name() string {
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
manager := &registryConfigurator{ manager := &registryConfigurator{
guid: s.Guid, guid: s.Guid,
gpo: s.GPO, gpo: s.GPO,
nrptEntryCount: s.NRPTEntryCount,
} }
if err := manager.restoreUncleanShutdownDNS(); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {

View File

@@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
upstreamClient upstreamClient upstreamClient upstreamClient
upstreamServers []string upstreamServers []netip.AddrPort
domain string domain string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
@@ -79,17 +80,20 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
// String returns a string representation of the upstream resolver // String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string { func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %v", u.upstreamServers) return fmt.Sprintf("upstream %s", u.upstreamServers)
} }
// ID returns the unique handler ID // ID returns the unique handler ID
func (u *upstreamResolverBase) ID() types.HandlerID { func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers) servers := slices.Clone(u.upstreamServers)
slices.Sort(servers) slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New() hash := sha256.New()
hash.Write([]byte(u.domain + ":")) hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ","))) for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
} }
@@ -130,7 +134,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func() { func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel() defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r) rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}() }()
if err != nil { if err != nil {
@@ -197,7 +201,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
proto.SystemEvent_DNS, proto.SystemEvent_DNS,
"All upstream servers failed (fail count exceeded)", "All upstream servers failed (fail count exceeded)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, map[string]string{"upstreams": u.upstreamServersString()},
// TODO add domain meta // TODO add domain meta
) )
} }
@@ -258,7 +262,7 @@ func (u *upstreamResolverBase) ProbeAvailability() {
proto.SystemEvent_DNS, proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)", "All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, map[string]string{"upstreams": u.upstreamServersString()},
) )
} }
} }
@@ -278,7 +282,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
operation := func() error { operation := func() error {
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers)) return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default: default:
} }
@@ -291,7 +295,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
} }
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff()) log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error") return fmt.Errorf("upstream check call error")
} }
@@ -301,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return return
} }
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.failsCount.Store(0) u.failsCount.Store(0)
u.successCount.Add(1) u.successCount.Add(1)
u.reactivate() u.reactivate()
@@ -331,13 +335,21 @@ func (u *upstreamResolverBase) disable(err error) {
go u.waitUntilResponse() go u.waitUntilResponse()
} }
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error { func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout) ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel() defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(ctx, server, r) _, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
return err return err
} }

View File

@@ -79,8 +79,8 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
} }
func (u *upstreamResolver) isLocalResolver(upstream string) bool { func (u *upstreamResolver) isLocalResolver(upstream string) bool {
if u.hostsDNSHolder.isContain(upstream) { if addrPort, err := netip.ParseAddrPort(upstream); err == nil {
return true return u.hostsDNSHolder.contains(addrPort)
} }
return false return false
} }

View File

@@ -62,6 +62,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP, err := netip.ParseAddr(upstreamHost) upstreamIP, err := netip.ParseAddr(upstreamHost)
if err != nil { if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err) log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
} else {
upstreamIP = upstreamIP.Unmap()
} }
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() { if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)

View File

@@ -59,7 +59,14 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers // Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
if addrPort, err := netip.ParseAddrPort(server); err == nil {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {
cancel() cancel()
@@ -128,7 +135,8 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
} }
resolver.upstreamServers = []string{"0.0.0.0:-1"} addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
resolver.failsTillDeact = 0 resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100 resolver.reactivatePeriod = time.Microsecond * 100

View File

@@ -165,7 +165,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
defer cancel() defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain) ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil { if err != nil {
f.handleDNSError(w, query, resp, domain, err) f.handleDNSError(ctx, w, question, resp, domain, err)
return nil return nil
} }
@@ -244,20 +244,57 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
} }
} }
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response // handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) { func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError var dnsErr *net.DNSError
switch { switch {
case errors.As(err, &dnsErr): case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound { if dnsErr.IsNotFound {
// Pass through NXDOMAIN f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
resp.Rcode = dns.RcodeNameError
} }
if dnsErr.Server != "" { if dnsErr.Server != "" {
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err) log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
} else { } else {
log.Warnf(errResolveFailed, domain, err) log.Warnf(errResolveFailed, domain, err)
} }

View File

@@ -3,6 +3,7 @@ package dnsfwd
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strings" "strings"
"testing" "testing"
@@ -16,8 +17,8 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
func Test_getMatchingEntries(t *testing.T) { func Test_getMatchingEntries(t *testing.T) {
@@ -708,6 +709,131 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
assert.Len(t, matches, 3, "Should match 3 patterns") assert.Len(t, matches, 3, "Should match 3 patterns")
} }
// TestDNSForwarder_NodataVsNxdomain tests that the forwarder correctly distinguishes
// between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of that type)
func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res", Set: set}}
forwarder.UpdateDomains(entries)
tests := []struct {
name string
queryType uint16
setupMocks func()
expectedCode int
expectNoAnswer bool // true if we expect NOERROR with empty answer (NODATA case)
description string
}{
{
name: "domain exists but no AAAA records (NODATA)",
queryType: dns.TypeAAAA,
setupMocks: func() {
// First query for AAAA returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for A records succeeds (domain exists)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: true,
description: "Should return NOERROR when domain exists but has no records of requested type",
},
{
name: "domain exists but no A records (NODATA)",
queryType: dns.TypeA,
setupMocks: func() {
// First query for A returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for AAAA records succeeds (domain exists)
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("2001:db8::1")}, nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: true,
description: "Should return NOERROR when domain exists but has no A records",
},
{
name: "domain doesn't exist (NXDOMAIN)",
queryType: dns.TypeA,
setupMocks: func() {
// First query for A returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for AAAA also returns not found (domain doesn't exist)
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
},
expectedCode: dns.RcodeNameError,
expectNoAnswer: true,
description: "Should return NXDOMAIN when domain doesn't exist at all",
},
{
name: "domain exists with records (normal success)",
queryType: dns.TypeA,
setupMocks: func() {
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
// Expect firewall update for successful resolution
expectedPrefix := netip.PrefixFrom(netip.MustParseAddr("1.2.3.4"), 32)
mockFirewall.On("UpdateSet", set, []netip.Prefix{expectedPrefix}).Return(nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: false,
description: "Should return NOERROR with answer when records exist",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset mock expectations
mockResolver.ExpectedCalls = nil
mockResolver.Calls = nil
mockFirewall.ExpectedCalls = nil
mockFirewall.Calls = nil
tt.setupMocks()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
})
}
}
func TestDNSForwarder_EmptyQuery(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions // Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})

View File

@@ -27,6 +27,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
@@ -1564,13 +1565,14 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err

View File

@@ -1,6 +1,8 @@
package internal package internal
import ( import (
"net/netip"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -13,7 +15,7 @@ type MobileDependency struct {
TunAdapter device.TunAdapter TunAdapter device.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
NetworkChangeListener listener.NetworkChangeListener NetworkChangeListener listener.NetworkChangeListener
HostDNSAddresses []string HostDNSAddresses []netip.AddrPort
DnsReadyListener dns.ReadyListener DnsReadyListener dns.ReadyListener
// iOS only // iOS only

View File

@@ -140,7 +140,7 @@ type RosenpassState struct {
// whether it's enabled, and the last error message encountered during probing. // whether it's enabled, and the last error message encountered during probing.
type NSGroupState struct { type NSGroupState struct {
ID string ID string
Servers []string Servers []netip.AddrPort
Domains []string Domains []string
Enabled bool Enabled bool
Error error Error error

View File

@@ -593,17 +593,9 @@ func update(input ConfigInput) (*Config, error) {
return config, nil return config, nil
} }
// GetConfig read config file and return with Config. Errors out if it does not exist
func GetConfig(configPath string) (*Config, error) { func GetConfig(configPath string) (*Config, error) {
if !fileExists(configPath) { return readConfig(configPath, false)
return nil, fmt.Errorf("config file %s does not exist", configPath)
}
config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", configPath, err)
}
return config, nil
} }
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
@@ -695,6 +687,11 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) { func ReadConfig(configPath string) (*Config, error) {
return readConfig(configPath, true)
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
if fileExists(configPath) { if fileExists(configPath) {
err := util.EnforcePermission(configPath) err := util.EnforcePermission(configPath)
if err != nil { if err != nil {
@@ -715,6 +712,8 @@ func ReadConfig(configPath string) (*Config, error) {
} }
return config, nil return config, nil
} else if !createIfMissing {
return nil, fmt.Errorf("config file %s does not exist", configPath)
} }
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath}) cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})

View File

@@ -16,19 +16,21 @@
<StandardDirectory Id="ProgramFiles64Folder"> <StandardDirectory Id="ProgramFiles64Folder">
<Directory Id="NetbirdInstallDir" Name="Netbird"> <Directory Id="NetbirdInstallDir" Name="Netbird">
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64"> <Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird.exe" KeyPath="yes" /> <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird-ui.exe"> <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" /> <Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" /> <Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
</File> </File>
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\wintun.dll" /> <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\opengl32.dll" /> <?if $(var.ArchSuffix) = "amd64" ?>
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
<?endif ?>
<ServiceInstall <ServiceInstall
Id="NetBirdService" Id="NetBirdService"
Name="NetBird" Name="NetBird"
DisplayName="NetBird" DisplayName="NetBird"
Description="A WireGuard-based mesh network that connects your devices into a single private network." Description="Connect your devices into a secure WireGuard-based overlay network with SSO, MFA and granular access controls."
Start="auto" Type="ownProcess" Start="auto" Type="ownProcess"
ErrorControl="normal" ErrorControl="normal"
Account="LocalSystem" Account="LocalSystem"

View File

@@ -1197,8 +1197,14 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
if dnsState.Error != nil { if dnsState.Error != nil {
err = dnsState.Error.Error() err = dnsState.Error.Error()
} }
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{ pbDnsState := &proto.NSGroupState{
Servers: dnsState.Servers, Servers: servers,
Domains: dnsState.Domains, Domains: dnsState.Domains,
Enabled: dnsState.Enabled, Enabled: dnsState.Enabled,
Error: err, Error: err,

View File

@@ -14,6 +14,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -302,13 +303,14 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err

View File

@@ -46,7 +46,7 @@ func (s *serviceClient) showProfilesUI() {
widget.NewLabel(""), // profile name widget.NewLabel(""), // profile name
layout.NewSpacer(), layout.NewSpacer(),
widget.NewButton("Select", nil), widget.NewButton("Select", nil),
widget.NewButton("Logout", nil), widget.NewButton("Deregister", nil),
widget.NewButton("Remove", nil), widget.NewButton("Remove", nil),
) )
}, },
@@ -128,7 +128,7 @@ func (s *serviceClient) showProfilesUI() {
} }
logoutBtn.Show() logoutBtn.Show()
logoutBtn.SetText("Logout") logoutBtn.SetText("Deregister")
logoutBtn.OnTapped = func() { logoutBtn.OnTapped = func() {
s.handleProfileLogout(profile.Name, refresh) s.handleProfileLogout(profile.Name, refresh)
} }
@@ -143,7 +143,7 @@ func (s *serviceClient) showProfilesUI() {
if !confirm { if !confirm {
return return
} }
err = s.removeProfile(profile.Name) err = s.removeProfile(profile.Name)
if err != nil { if err != nil {
log.Errorf("failed to remove profile: %v", err) log.Errorf("failed to remove profile: %v", err)
@@ -334,27 +334,27 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) { func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
dialog.ShowConfirm( dialog.ShowConfirm(
"Logout", "Deregister",
fmt.Sprintf("Are you sure you want to logout from '%s'?", profileName), fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
func(confirm bool) { func(confirm bool) {
if !confirm { if !confirm {
return return
} }
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("failed to get service client: %v", err) log.Errorf("failed to get service client: %v", err)
dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles) dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles)
return return
} }
currUser, err := user.Current() currUser, err := user.Current()
if err != nil { if err != nil {
log.Errorf("failed to get current user: %v", err) log.Errorf("failed to get current user: %v", err)
dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles) dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles)
return return
} }
username := currUser.Username username := currUser.Username
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{ _, err = conn.Logout(s.ctx, &proto.LogoutRequest{
ProfileName: &profileName, ProfileName: &profileName,
@@ -362,16 +362,16 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
}) })
if err != nil { if err != nil {
log.Errorf("logout failed: %v", err) log.Errorf("logout failed: %v", err)
dialog.ShowError(fmt.Errorf("logout failed"), s.wProfiles) dialog.ShowError(fmt.Errorf("deregister failed"), s.wProfiles)
return return
} }
dialog.ShowInformation( dialog.ShowInformation(
"Logged Out", "Deregistered",
fmt.Sprintf("Successfully logged out from '%s'", profileName), fmt.Sprintf("Successfully deregistered from '%s'", profileName),
s.wProfiles, s.wProfiles,
) )
refreshCallback() refreshCallback()
}, },
s.wProfiles, s.wProfiles,
@@ -602,7 +602,7 @@ func (p *profileMenu) refresh() {
// Add Logout menu item // Add Logout menu item
ctx2, cancel2 := context.WithCancel(context.Background()) ctx2, cancel2 := context.WithCancel(context.Background())
logoutItem := p.profileMenuItem.AddSubMenuItem("Logout", "") logoutItem := p.profileMenuItem.AddSubMenuItem("Deregister", "")
p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2} p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2}
go func() { go func() {
@@ -616,9 +616,9 @@ func (p *profileMenu) refresh() {
} }
if err := p.eventHandler.logout(p.ctx); err != nil { if err := p.eventHandler.logout(p.ctx); err != nil {
log.Errorf("logout failed: %v", err) log.Errorf("logout failed: %v", err)
p.app.SendNotification(fyne.NewNotification("Error", "Failed to logout")) p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
} else { } else {
p.app.SendNotification(fyne.NewNotification("Success", "Logged out successfully")) p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
} }
} }
} }

View File

@@ -102,6 +102,11 @@ func (n *NameServer) IsEqual(other *NameServer) bool {
other.Port == n.Port other.Port == n.Port
} }
// AddrPort returns the nameserver as a netip.AddrPort
func (n *NameServer) AddrPort() netip.AddrPort {
return netip.AddrPortFrom(n.IP, uint16(n.Port))
}
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53 // ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
func ParseNameServerURL(nsURL string) (NameServer, error) { func ParseNameServerURL(nsURL string) (NameServer, error) {
parsedURL, err := url.Parse(nsURL) parsedURL, err := url.Parse(nsURL)

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f h1:YmqNWdRbeVn1lSpkLzIiFHX2cndRuaVYyynx2ibrOtg= github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e h1:S85laGfx1UP+nmRF9smP6/TY965kLWz41PbBK1TX8g0=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e/go.mod h1:Jjve0+eUjOLKL3PJtAhjfM2iJ0SxWio5elHqlV1ymP8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -34,6 +34,7 @@ import (
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -45,7 +46,6 @@ import (
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http" nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/metrics"
@@ -220,7 +220,8 @@ var (
return fmt.Errorf("build default manager: %v", err) return fmt.Errorf("build default manager: %v", err)
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager) groupsManager := groups.NewManager(store, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager, groupsManager)
trustedPeers := config.ReverseProxy.TrustedPeers trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")} defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
@@ -277,7 +278,6 @@ var (
config.GetAuthAudiences(), config.GetAuthAudiences(),
config.HttpConfig.IdpSignKeyRefreshEnabled) config.HttpConfig.IdpSignKeyRefreshEnabled)
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager) resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
routersManager := routers.NewManager(store, permissionsManager, accountManager) routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)

View File

@@ -40,12 +40,12 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -346,12 +346,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
} }
if updateAccountPeers || groupsUpdated { if updateAccountPeers || groupsUpdated {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
} }
return transaction.SaveAccountSettings(ctx, store.LockingStrengthUpdate, accountID, newSettings) return transaction.SaveAccountSettings(ctx, accountID, newSettings)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -405,7 +405,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
} }
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -746,7 +746,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
// AccountExists checks if an account exists. // AccountExists checks if an account exists.
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountIDByUserID retrieves the account ID based on the userID provided. // GetAccountIDByUserID retrieves the account ID based on the userID provided.
@@ -758,7 +758,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided") return "", status.Errorf(status.NotFound, "no valid userID provided")
} }
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@@ -813,7 +813,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID) accountIDString := fmt.Sprintf("%v", accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -867,7 +867,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -897,7 +897,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err return nil, err
@@ -1048,7 +1048,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount() defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return err return err
@@ -1058,7 +1058,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return nil return nil
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err) log.WithContext(ctx).Errorf("error getting user: %v", err)
return err return err
@@ -1145,7 +1145,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
newUser := types.NewRegularUser(userAuth.UserId) newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) err := am.Store.SaveUser(ctx, newUser)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -1223,7 +1223,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountOnboarding retrieves the onboarding information for a specific account. // GetAccountOnboarding retrieves the onboarding information for a specific account.
@@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return "", "", err return "", "", err
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
// this is not really possible because we got an account by user ID // this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
@@ -1340,7 +1340,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return err return err
} }
@@ -1366,12 +1366,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
var hasChanges bool var hasChanges bool
var user *types.User var user *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user: %w", err) return fmt.Errorf("error getting user: %w", err)
} }
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -1387,7 +1387,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil { if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }
@@ -1395,13 +1395,13 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups) removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)
user.AutoGroups = updatedAutoGroups user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, store.LockingStrengthUpdate, user); err != nil { if err = transaction.SaveUser(ctx, user); err != nil {
return fmt.Errorf("error saving user: %w", err) return fmt.Errorf("error saving user: %w", err)
} }
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user peers: %w", err) return fmt.Errorf("error getting user peers: %w", err)
} }
@@ -1419,7 +1419,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil { if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err) return fmt.Errorf("error incrementing network serial: %w", err)
} }
} }
@@ -1437,7 +1437,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1450,7 +1450,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1511,7 +1511,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
} }
if userAuth.IsChild { if userAuth.IsChild {
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil || !exists { if err != nil || !exists {
return "", err return "", err
} }
@@ -1535,7 +1535,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err return "", err
} }
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1556,7 +1556,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth) return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
} }
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1571,7 +1571,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
// check again if the domain has a primary account because of simultaneous requests // check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
cancel() cancel()
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1582,7 +1582,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
} }
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1592,7 +1592,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId) return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
} }
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err return "", err
@@ -1603,7 +1603,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
} }
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err return "", err
@@ -1751,7 +1751,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
} }
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -1780,7 +1780,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
if !allowed { if !allowed {
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
} }
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
@@ -1870,7 +1870,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel() defer cancel()
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
return nil, false, err return nil, false, err
} }
@@ -1890,7 +1890,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
for range 2 { for range 2 {
accountId := xid.New().String() accountId := xid.New().String()
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId)
if err != nil || exists { if err != nil || exists {
continue continue
} }
@@ -1965,7 +1965,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return nil return nil
} }
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain) existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
// error is not a not found error // error is not a not found error
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
@@ -2002,17 +2002,17 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
// propagateUserGroupMemberships propagates all account users' group memberships to their peers. // propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error. // Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID) accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, fmt.Errorf("error getting account group peers: %w", err) return false, false, fmt.Errorf("error getting account group peers: %w", err)
} }
accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, fmt.Errorf("error getting account groups: %w", err) return false, false, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -2025,7 +2025,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
updatedGroups := []string{} updatedGroups := []string{}
for _, user := range users { for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id) userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
@@ -2074,7 +2074,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
account.Network.Net = newIPNet account.Network.Net = newIPNet
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -2099,7 +2099,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
} }
for _, peer := range peers { for _, peer := range peers {
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err) return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
} }
} }
@@ -2154,7 +2154,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context,
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return fmt.Errorf("get peer: %w", err) return fmt.Errorf("get peer: %w", err)
} }
@@ -2185,7 +2185,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context,
func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error { func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error {
log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP) log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return fmt.Errorf("get account settings: %w", err) return fmt.Errorf("get account settings: %w", err)
} }
@@ -2195,7 +2195,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
oldIP := peer.IP.String() oldIP := peer.IP.String()
peer.IP = newIP.AsSlice() peer.IP = newIP.AsSlice()
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer) err = transaction.SavePeer(ctx, accountID, peer)
if err != nil { if err != nil {
return fmt.Errorf("save peer: %w", err) return fmt.Errorf("save peer: %w", err)
} }

View File

@@ -783,7 +783,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
return return
} }
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid") assert.True(t, exists, "expected to get existing account after creation using userid")
@@ -900,11 +900,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
} }
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1") pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId) pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
} }
@@ -1786,7 +1786,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings) assert.NotNil(t, settings)
@@ -1971,7 +1971,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updatedSettings.PeerLoginExpirationEnabled) assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled) assert.False(t, settings.PeerLoginExpirationEnabled)
@@ -2655,7 +2655,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
}) })
@@ -2669,7 +2669,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty") assert.Empty(t, user.AutoGroups, "auto groups must be empty")
}) })
@@ -2683,18 +2683,18 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"} account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
claims := nbcontext.UserAuth{ claims := nbcontext.UserAuth{
UserId: "user1", UserId: "user1",
@@ -2704,11 +2704,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2722,7 +2722,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2736,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2750,11 +2750,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
@@ -2768,7 +2768,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
@@ -2783,7 +2783,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
}) })
@@ -3348,11 +3348,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1) err = manager.Store.AddPeerToAccount(ctx, peer1)
require.NoError(t, err) require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2) err = manager.Store.AddPeerToAccount(ctx, peer2)
require.NoError(t, err) require.NoError(t, err)
t.Run("should skip propagation when the user has no groups", func(t *testing.T) { t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
@@ -3364,20 +3364,20 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) { t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id} group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1)) require.NoError(t, manager.Store.CreateGroup(ctx, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID) user.AutoGroups = append(user.AutoGroups, group1.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID) group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1") assert.Contains(t, group.Peers, "peer1")
@@ -3386,13 +3386,13 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership and account peers for used groups", func(t *testing.T) { t.Run("should update membership and account peers for used groups", func(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id} group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2)) require.NoError(t, manager.Store.CreateGroup(ctx, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID) user.AutoGroups = append(user.AutoGroups, group2.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
Name: "Group1 Policy", Name: "Group1 Policy",
@@ -3415,7 +3415,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers) assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
@@ -3432,18 +3432,18 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}) })
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) { t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = []string{"group1"} user.AutoGroups = []string{"group1"}
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, groupsUpdated) assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)

View File

@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
return userAuth, nil return userAuth, nil
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return userAuth, err return userAuth, err
} }
@@ -94,7 +94,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
// MarkPATUsed marks a personal access token as used // MarkPATUsed marks a personal access token as used
func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error { func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error {
return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) return am.store.MarkPATUsed(ctx, tokenID)
} }
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token. // GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
return nil, nil, "", "", err return nil, nil, "", "", err
} }
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
if err != nil { if err != nil {
return nil, nil, "", "", err return nil, nil, "", "", err
} }
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
var pat *types.PersonalAccessToken var pat *types.PersonalAccessToken
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
if err != nil { if err != nil {
return err return err
} }
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
return err return err
}) })
if err != nil { if err != nil {

View File

@@ -8,14 +8,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
) )
// DNSConfigCache is a thread-safe cache for DNS configuration components // DNSConfigCache is a thread-safe cache for DNS configuration components
@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
} }
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
@@ -113,11 +113,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave) return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
var eventsToStore []func() var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil return nil
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return nil return nil
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -134,7 +134,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
} }
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return return

View File

@@ -11,9 +11,9 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
func isEnabled() bool { func isEnabled() bool {
@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
} }
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) { func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
continue continue
} }
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
if err != nil { if err != nil {
// @todo consider logging // @todo consider logging
continue continue

View File

@@ -14,11 +14,11 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
type GroupLinkError struct { type GroupLinkError struct {
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
} }
// CreateGroup object of the peers // CreateGroup object of the peers
@@ -96,11 +96,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil { if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err) return status.Errorf(status.Internal, "failed to create group: %v", err)
} }
@@ -147,7 +147,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err return err
} }
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err != nil { if err != nil {
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
} }
@@ -176,11 +176,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup) return transaction.UpdateGroup(ctx, newGroup)
}) })
if err != nil { if err != nil {
return err return err
@@ -234,11 +234,11 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) return transaction.CreateGroups(ctx, accountID, groupsToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -292,11 +292,11 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) return transaction.UpdateGroups(ctx, accountID, groupsToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -320,7 +320,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err == nil && oldGroup != nil { if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
@@ -332,13 +332,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
} }
modifiedPeers := slices.Concat(addedPeers, removedPeers) modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil return nil
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil return nil
@@ -423,11 +423,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
deletedGroups = append(deletedGroups, group) deletedGroups = append(deletedGroups, group)
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete) return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete)
}) })
if err != nil { if err != nil {
return err return err
@@ -454,7 +454,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -495,11 +495,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) return transaction.UpdateGroup(ctx, group)
}) })
if err != nil { if err != nil {
return err return err
@@ -526,7 +526,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -567,11 +567,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) return transaction.UpdateGroup(ctx, group)
}) })
if err != nil { if err != nil {
return err return err
@@ -591,7 +591,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
if err != nil { if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err return err
@@ -608,7 +608,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
for _, peerID := range newGroup.Peers { for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
@@ -620,7 +620,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration { if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get user") return status.Errorf(status.Internal, "failed to get user")
} }
@@ -666,7 +666,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings") return status.Errorf(status.Internal, "failed to get DNS settings")
} }
@@ -675,7 +675,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get account settings") return status.Errorf(status.Internal, "failed to get account settings")
} }
@@ -689,7 +689,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil return false, nil
@@ -709,7 +709,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil return false, nil
@@ -727,7 +727,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil return false, nil
@@ -746,7 +746,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil return false, nil
@@ -762,7 +762,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil return false, nil
@@ -778,7 +778,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account. // isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) { func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil return false, nil
@@ -798,7 +798,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil return false, nil
} }
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -826,7 +826,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -26,10 +26,10 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer" peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -898,7 +898,7 @@ func Test_AddPeerAndAddToAll(t *testing.T) {
} }
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) err = transaction.AddPeerToAccount(context.Background(), peer)
if err != nil { if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err) return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
} }
@@ -971,7 +971,7 @@ func Test_IncrementNetworkSerial(t *testing.T) {
<-start <-start
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID) err = transaction.IncrementNetworkSerial(context.Background(), accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account %s: %v", accountID, err) return fmt.Errorf("failed to get account %s: %v", accountID, err)
} }

View File

@@ -6,12 +6,12 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
) )
type Manager interface { type Manager interface {
@@ -21,6 +21,7 @@ type Manager interface {
AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error
AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error)
RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error)
GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error)
} }
type managerImpl struct { type managerImpl struct {
@@ -49,7 +50,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
return nil, err return nil, err
} }
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -96,13 +97,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
return nil, fmt.Errorf("error adding resource to group: %w", err) return nil, fmt.Errorf("error adding resource to group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }
@@ -120,13 +121,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
return nil, fmt.Errorf("error removing resource from group: %w", err) return nil, fmt.Errorf("error removing resource from group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }
@@ -142,6 +143,10 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa
return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID)
} }
func (m *managerImpl) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
return m.store.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID)
}
func ToGroupsInfoMap(groups []*types.Group, idCount int) map[string][]api.GroupMinimum { func ToGroupsInfoMap(groups []*types.Group, idCount int) map[string][]api.GroupMinimum {
groupsInfoMap := make(map[string][]api.GroupMinimum, idCount) groupsInfoMap := make(map[string][]api.GroupMinimum, idCount)
groupsChecked := make(map[string]struct{}, len(groups)) // not sure why this is needed (left over from old implementation) groupsChecked := make(map[string]struct{}, len(groups)) // not sure why this is needed (left over from old implementation)
@@ -202,6 +207,10 @@ func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context,
}, nil }, nil
} }
func (m *mockManager) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
return []string{}, nil
}
func NewManagerMock() Manager { func NewManagerMock() Manager {
return &mockManager{} return &mockManager{}
} }

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
@@ -32,9 +31,10 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
internalStatus "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
internalStatus "github.com/netbirdio/netbird/shared/management/status"
) )
// GRPCServer an instance of a Management gRPC API server // GRPCServer an instance of a Management gRPC API server
@@ -662,7 +662,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
} }
} }
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings) *proto.SyncResponse { func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
response := &proto.SyncResponse{ response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
@@ -674,7 +674,7 @@ func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer
} }
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings) extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig response.NetbirdConfig = extendedConfig
response.NetworkMap.PeerConfig = response.PeerConfig response.NetworkMap.PeerConfig = response.PeerConfig
@@ -750,7 +750,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
return status.Errorf(codes.Internal, "error handling request") return status.Errorf(codes.Internal, "error handling request")
} }
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra) peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID)
if err != nil {
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
}
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {
@@ -913,6 +918,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
start := time.Now()
empty := &proto.Empty{} empty := &proto.Empty{}
peerKey, err := s.parseRequest(ctx, req, empty) peerKey, err := s.parseRequest(ctx, req, empty)
@@ -920,7 +926,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
return nil, err return nil, err
} }
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peerKey.String()) peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String()) log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String())
// TODO: consider idempotency // TODO: consider idempotency
@@ -944,7 +950,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID) s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String()) log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
return &proto.Empty{}, nil return &proto.Empty{}, nil
} }

View File

@@ -199,6 +199,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
settings.Extra = &types.ExtraSettings{ settings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
} }
} }
@@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
apiSettings.Extra = &api.AccountExtraSettings{ apiSettings.Extra = &api.AccountExtraSettings{
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled, NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled, NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,
} }
} }

View File

@@ -77,7 +77,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID) _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return err return err
} }
@@ -97,17 +97,17 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
var settings *types.Settings var settings *types.Settings
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -22,14 +22,15 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -446,6 +447,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
Return(&types.ExtraSettings{}, nil). Return(&types.ExtraSettings{}, nil).
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
@@ -455,7 +457,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
return nil, nil, "", cleanup, err return nil, nil, "", cleanup, err
} }
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
ephemeralMgr := NewEphemeralManager(store, accountManager) ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
@@ -645,7 +647,7 @@ func testSyncStatusRace(t *testing.T) {
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

View File

@@ -23,6 +23,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -216,7 +217,8 @@ func startServer(
t.Fatalf("failed creating an account manager: %v", err) t.Fatalf("failed creating an account manager: %v", err)
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) groupsManager := groups.NewManager(str, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer( mgmtServer, err := server.NewServer(
context.Background(), context.Background(),
config, config,

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
func GetColumnName(db *gorm.DB, column string) string { func GetColumnName(db *gorm.DB, column string) string {
@@ -466,7 +467,7 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
} }
for _, value := range data { for _, value := range data {
if err := tx.Create( if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(
mapperFunc(row["account_id"].(string), row["id"].(string), value), mapperFunc(row["account_id"].(string), row["id"].(string), value),
).Error; err != nil { ).Error; err != nil {
return fmt.Errorf("failed to insert id %v: %w", row["id"], err) return fmt.Errorf("failed to insert id %v: %w", row["id"], err)

View File

@@ -13,9 +13,9 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$` const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID) return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
@@ -73,11 +73,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup) return transaction.SaveNameServerGroup(ctx, newNSGroup)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
var updateAccountPeers bool var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID) oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -127,11 +127,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave) return transaction.SaveNameServerGroup(ctx, nsGroupToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -173,11 +173,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) return transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID)
}) })
if err != nil { if err != nil {
return err return err
@@ -202,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
} }
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
@@ -216,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err return err
} }
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -226,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err return err
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -14,8 +14,8 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID) return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
@@ -73,7 +73,7 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
defer unlock() defer unlock()
err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) err = m.store.SaveNetwork(ctx, network)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to save network: %w", err) return nil, fmt.Errorf("failed to save network: %w", err)
} }
@@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
@@ -114,7 +114,7 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta()) m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) return network, m.store.SaveNetwork(ctx, network)
} }
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
@@ -162,12 +162,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
eventsToStore = append(eventsToStore, event) eventsToStore = append(eventsToStore, event)
} }
err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) err = transaction.DeleteNetwork(ctx, accountID, networkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete network: %w", err) return fmt.Errorf("failed to delete network: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }

View File

@@ -12,10 +12,10 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
nbtypes "github.com/netbirdio/netbird/management/server/types" nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
@@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
@@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network resources: %w", err) return nil, fmt.Errorf("failed to get network resources: %w", err)
} }
@@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
var eventsToStore []func() var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil { if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name) return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
} }
@@ -123,7 +123,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) err = transaction.SaveNetworkResource(ctx, resource)
if err != nil { if err != nil {
return fmt.Errorf("failed to save network resource: %w", err) return fmt.Errorf("failed to save network resource: %w", err)
} }
@@ -145,7 +145,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
eventsToStore = append(eventsToStore, event) eventsToStore = append(eventsToStore, event)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network resource: %w", err) return nil, fmt.Errorf("failed to get network resource: %w", err)
} }
@@ -218,22 +218,22 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID) return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
} }
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network resource: %w", err)
} }
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil && oldResource.ID != resource.ID { if err == nil && oldResource.ID != resource.ID {
return status.Errorf(status.InvalidArgument, "new resource name already exists") return status.Errorf(status.InvalidArgument, "new resource name already exists")
} }
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network resource: %w", err)
} }
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) err = transaction.SaveNetworkResource(ctx, resource)
if err != nil { if err != nil {
return fmt.Errorf("failed to save network resource: %w", err) return fmt.Errorf("failed to save network resource: %w", err)
} }
@@ -248,7 +248,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network)) m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
}) })
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -325,7 +325,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return fmt.Errorf("failed to delete resource: %w", err) return fmt.Errorf("failed to delete resource: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -375,7 +375,7 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
eventsToStore = append(eventsToStore, event) eventsToStore = append(eventsToStore, event)
} }
err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) err = transaction.DeleteNetworkResource(ctx, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete network resource: %w", err) return nil, fmt.Errorf("failed to delete network resource: %w", err)
} }

View File

@@ -14,8 +14,8 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
@@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network routers: %w", err) return nil, fmt.Errorf("failed to get network routers: %w", err)
} }
@@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
@@ -104,12 +104,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
router.ID = xid.New().String() router.ID = xid.New().String()
err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) err = transaction.SaveNetworkRouter(ctx, router)
if err != nil { if err != nil {
return fmt.Errorf("failed to create network router: %w", err) return fmt.Errorf("failed to create network router: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID) router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network router: %w", err) return nil, fmt.Errorf("failed to get network router: %w", err)
} }
@@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
@@ -171,12 +171,12 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
} }
err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) err = transaction.SaveNetworkRouter(ctx, router)
if err != nil { if err != nil {
return fmt.Errorf("failed to update network router: %w", err) return fmt.Errorf("failed to update network router: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -213,7 +213,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
return fmt.Errorf("failed to delete network router: %w", err) return fmt.Errorf("failed to delete network router: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
} }
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err) return nil, fmt.Errorf("failed to get network: %w", err)
} }
@@ -246,7 +246,7 @@ func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction
return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID) return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
} }
err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) err = transaction.DeleteNetworkRouter(ctx, accountID, routerID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete network router: %w", err) return nil, fmt.Errorf("failed to delete network router: %w", err)
} }

View File

@@ -17,28 +17,28 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
} }
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter) accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return accountPeers, nil return accountPeers, nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get account settings: %w", err) return nil, fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -130,7 +130,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
} }
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -173,7 +173,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
peer.Location.CountryCode = location.Country.ISOCode peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID peer.Location.GeoNameID = location.City.GeonameID
err = transaction.SavePeerLocation(ctx, store.LockingStrengthUpdate, accountID, peer) err = transaction.SavePeerLocation(ctx, accountID, peer)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
} }
@@ -182,7 +182,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := transaction.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *newStatus) err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -219,7 +219,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err return err
} }
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -281,7 +281,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
inactivityExpirationChanged = true inactivityExpirationChanged = true
} }
return transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer) return transaction.SavePeer(ctx, accountID, peer)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -346,7 +346,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil { if err != nil {
return err return err
} }
@@ -383,7 +383,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return fmt.Errorf("failed to delete peer: %w", err) return fmt.Errorf("failed to delete peer: %w", err)
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -609,7 +609,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
newPeer.DNSLabel = freeLabel newPeer.DNSLabel = freeLabel
newPeer.IP = freeIP newPeer.IP = freeIP
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer func() { defer func() {
if unlock != nil { if unlock != nil {
unlock() unlock()
@@ -617,7 +617,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}() }()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) err = transaction.AddPeerToAccount(ctx, newPeer)
if err != nil { if err != nil {
return err return err
} }
@@ -658,7 +658,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -734,7 +734,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var err error var err error
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -746,7 +746,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
} }
if peer.UserID != "" { if peer.UserID != "" {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil { if err != nil {
return err return err
} }
@@ -774,7 +774,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
if updated { if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate() am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return err return err
} }
@@ -849,7 +849,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
var isPeerUpdated bool var isPeerUpdated bool
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -911,7 +911,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
} }
if shouldStorePeer { if shouldStorePeer {
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return err return err
} }
} }
@@ -934,7 +934,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
// getPeerPostureChecks returns the posture checks for the peer. // getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -958,7 +958,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
} }
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs) peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, peerPostureChecksIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -973,7 +973,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
continue continue
} }
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources) sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -998,7 +998,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests. // and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error { func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, login.WireGuardPubKey)
if err != nil { if err != nil {
return err return err
} }
@@ -1009,7 +1009,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -1080,7 +1080,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
// If peer was expired before and if it reached this point, it is re-authenticated. // If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer. // UserID is present, meaning that JWT validation passed successfully in the API layer.
peer = peer.UpdateLastLogin() peer = peer.UpdateLastLogin()
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, peer.AccountID, peer) err = transaction.SavePeer(ctx, peer.AccountID, peer)
if err != nil { if err != nil {
return err return err
} }
@@ -1090,7 +1090,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
log.WithContext(ctx).Debugf("failed to update user last login: %v", err) log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, peer.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account settings: %w", err) return fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -1132,7 +1132,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1145,7 +1145,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return peer, nil return peer, nil
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1171,7 +1171,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// it is also possible that user doesn't own the peer but some of his peers have access to it, // it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well. // this is a valid case, show the peer as well.
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1275,8 +1275,9 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
} }
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now() start = time.Now()
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups))
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
@@ -1386,7 +1387,8 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return return
} }
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings) peerGroups := account.GetPeerGroups(peerId)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups))
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
} }
@@ -1394,7 +1396,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
// If there is no peer that expires this function returns false and a duration of 0. // If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected. // This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1404,7 +1406,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
return 0, false return 0, false
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err) log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1438,7 +1440,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
// If there is no peer that expires this function returns false and a duration of 0. // If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected. // This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1448,7 +1450,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
return 0, false return 0, false
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err) log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1479,12 +1481,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
// getExpiredPeers returns peers that have been expired. // getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1502,12 +1504,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
// getInactivePeers returns peers that have been expired by inactivity // getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1530,7 +1532,7 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p
// getPeerGroupIDs returns the IDs of the groups that the peer is part of. // getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID) return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
} }
// IsPeerInActiveGroup checks if the given peer is part of a group that is used // IsPeerInActiveGroup checks if the given peer is part of a group that is used
@@ -1548,7 +1550,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func() var peerDeletedEvents []func()
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1568,7 +1570,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return nil, err return nil, err
} }
if err = transaction.DeletePeer(ctx, store.LockingStrengthUpdate, accountID, peer.ID); err != nil { if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
return nil, err return nil, err
} }
@@ -1624,7 +1626,7 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transac
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account. // isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) { func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
return false, nil return false, nil

View File

@@ -38,8 +38,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
@@ -47,6 +45,8 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route" nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
) )
func TestPeer_LoginExpired(t *testing.T) { func TestPeer_LoginExpired(t *testing.T) {
@@ -1164,7 +1164,7 @@ func TestToSyncResponse(t *testing.T) {
} }
dnsCache := &DNSConfigCache{} dnsCache := &DNSConfigCache{}
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil) response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{})
assert.NotNil(t, response) assert.NotNil(t, response)
// assert peer config // assert peer config
@@ -1307,7 +1307,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels) assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key) peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID) assert.Equal(t, peer.UserID, existingUserID)
@@ -1442,7 +1442,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success") assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch") assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, currentPeer.Key) peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key)
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key) require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store") assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
@@ -1476,8 +1476,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
func Test_RegisterPeerRollbackOnFailure(t *testing.T) { func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
engine := os.Getenv("NETBIRD_STORE_ENGINE") engine := os.Getenv("NETBIRD_STORE_ENGINE")
if engine == "sqlite" || engine == "" { if engine == "sqlite" || engine == "mysql" || engine == "" {
t.Skip("Skipping test because sqlite test store is not respecting foreign keys") // we intentionally disabled foreign keys in mysql
t.Skip("Skipping test because store is not respecting foreign keys")
} }
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")
@@ -1528,7 +1529,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err) require.Error(t, err)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
require.Error(t, err) require.Error(t, err)
account, err := s.GetAccount(context.Background(), existingAccountID) account, err := s.GetAccount(context.Background(), existingAccountID)
@@ -1699,7 +1700,7 @@ func Test_LoginPeer(t *testing.T) {
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer") assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, loginInput.WireGuardPubKey) peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey)
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey) require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore") assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")
@@ -2160,10 +2161,10 @@ func Test_IsUniqueConstraintError(t *testing.T) {
} }
t.Cleanup(cleanup) t.Cleanup(cleanup)
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) err = s.AddPeerToAccount(context.Background(), peer)
assert.NoError(t, err) assert.NoError(t, err)
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) err = s.AddPeerToAccount(context.Background(), peer)
result := isUniqueConstraintError(err) result := isUniqueConstraintError(err)
assert.True(t, result) assert.True(t, result)
}) })

View File

@@ -10,8 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
} }
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
@@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
} }
if !allowed { if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
} }
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
} }
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
} }

View File

@@ -11,9 +11,9 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions(
return true, nil return true, nil
} }
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -6,11 +6,11 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID) return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
} }
// SavePolicy in the store // SavePolicy in the store
@@ -61,7 +61,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -71,7 +71,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
saveFunc = transaction.SavePolicy saveFunc = transaction.SavePolicy
} }
return saveFunc(ctx, store.LockingStrengthUpdate, policy) return saveFunc(ctx, policy)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -113,11 +113,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID) return transaction.DeletePolicy(ctx, accountID, policyID)
}) })
if err != nil { if err != nil {
return err return err
@@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
} }
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
if isUpdate { if isUpdate {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules. // validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" { if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
policy.AccountID = accountID policy.AccountID = accountID
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups()) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
if err != nil { if err != nil {
return err return err
} }
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks) postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -13,9 +13,9 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
} }
// SavePostureChecks saves a posture check. // SavePostureChecks saves a posture check.
@@ -62,7 +62,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -70,7 +70,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
} }
postureChecks.AccountID = accountID postureChecks.AccountID = accountID
return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks) return transaction.SavePostureChecks(ctx, postureChecks)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
var postureChecks *posture.Checks var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
if err != nil { if err != nil {
return err return err
} }
@@ -110,11 +110,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID) return transaction.DeletePostureChecks(ctx, accountID, postureChecksID)
}) })
if err != nil { if err != nil {
return err return err
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
} }
// getPeerPostureChecks returns the posture checks applied for a given peer. // getPeerPostureChecks returns the posture checks applied for a given peer.
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
// If the posture check already has an ID, verify its existence in the store. // If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" { if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil { if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil {
return err return err
} }
return nil return nil
} }
// For new posture checks, ensure no duplicates by name. // For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -9,15 +9,15 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
) )
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
@@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID)) return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID))
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
@@ -59,7 +59,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
seenPeers[string(prefixRoute.ID)] = true seenPeers[string(prefixRoute.ID)] = true
} }
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups) peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, prefixRoute.PeerGroups)
if err != nil { if err != nil {
return err return err
} }
@@ -83,7 +83,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
if peerID := checkRoute.Peer; peerID != "" { if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID) _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
@@ -104,7 +104,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
} }
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers) peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, group.Peers)
if err != nil { if err != nil {
return err return err
} }
@@ -181,11 +181,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute) return transaction.SaveRoute(ctx, newRoute)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -238,11 +238,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
} }
routeToSave.AccountID = accountID routeToSave.AccountID = accountID
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave) return transaction.SaveRoute(ctx, routeToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -284,11 +284,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) return transaction.DeleteRoute(ctx, accountID, string(routeID))
}) })
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
@@ -310,7 +310,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
} }
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error { func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
@@ -353,7 +353,7 @@ func validateRoute(ctx context.Context, transaction store.Store, accountID strin
// validateRouteGroups validates the route groups and returns the validated groups map. // validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) { func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate) groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupsToValidate)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -494,7 +494,7 @@ func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, ro
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix // GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -14,7 +14,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -27,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
const ( const (
@@ -1100,7 +1100,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *types.Group var groupHA1, groupHA2 *types.Group
for _, group := range groups { for _, group := range groups {

View File

@@ -11,10 +11,10 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -60,7 +60,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
return nil, fmt.Errorf("get extra settings: %w", err) return nil, fmt.Errorf("get extra settings: %w", err)
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account settings: %w", err) return nil, fmt.Errorf("get account settings: %w", err)
} }
@@ -68,6 +68,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
// Once we migrate the peer approval to settings manager this merging is obsolete // Once we migrate the peer approval to settings manager this merging is obsolete
if settings.Extra != nil { if settings.Extra != nil {
settings.Extra.FlowEnabled = extraSettings.FlowEnabled settings.Extra.FlowEnabled = extraSettings.FlowEnabled
settings.Extra.FlowGroups = extraSettings.FlowGroups
settings.Extra.FlowPacketCounterEnabled = extraSettings.FlowPacketCounterEnabled settings.Extra.FlowPacketCounterEnabled = extraSettings.FlowPacketCounterEnabled
settings.Extra.FlowENCollectionEnabled = extraSettings.FlowENCollectionEnabled settings.Extra.FlowENCollectionEnabled = extraSettings.FlowENCollectionEnabled
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
@@ -82,7 +83,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
return nil, fmt.Errorf("get extra settings: %w", err) return nil, fmt.Errorf("get extra settings: %w", err)
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account settings: %w", err) return nil, fmt.Errorf("get account settings: %w", err)
} }
@@ -93,6 +94,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
} }
settings.Extra.FlowEnabled = extraSettings.FlowEnabled settings.Extra.FlowEnabled = extraSettings.FlowEnabled
settings.Extra.FlowGroups = extraSettings.FlowGroups
return settings.Extra, nil return settings.Extra, nil
} }

View File

@@ -10,10 +10,10 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -81,7 +81,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey) return transaction.SaveSetupKey(ctx, setupKey)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
} }
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthUpdate, accountID, keyToSave.Id)
if err != nil { if err != nil {
return err return err
} }
@@ -148,7 +148,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey) return transaction.SaveSetupKey(ctx, newKey)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
} }
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -188,7 +188,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -214,12 +214,12 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
var deletedSetupKey *types.SetupKey var deletedSetupKey *types.SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthUpdate, accountID, keyID)
if err != nil { if err != nil {
return err return err
} }
return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID) return transaction.DeleteSetupKey(ctx, accountID, keyID)
}) })
if err != nil { if err != nil {
return err return err
@@ -231,7 +231,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
} }
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error { func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, autoGroupIDs)
if err != nil { if err != nil {
return err return err
} }
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
var eventsToStore []func() var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil return nil

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -72,8 +72,8 @@ type Store interface {
SaveAccount(ctx context.Context, account *types.Account) error SaveAccount(ctx context.Context, account *types.Account) error
DeleteAccount(ctx context.Context, account *types.Account) error DeleteAccount(ctx context.Context, account *types.Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error SaveDNSSettings(ctx context.Context, accountID string, settings *types.DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error SaveAccountSettings(ctx context.Context, accountID string, settings *types.Settings) error
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error
@@ -81,10 +81,10 @@ type Store interface {
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error)
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error SaveUsers(ctx context.Context, users []*types.User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error SaveUser(ctx context.Context, user *types.User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error DeleteUser(ctx context.Context, accountID, userID string) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
@@ -92,34 +92,34 @@ type Store interface {
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error MarkPATUsed(ctx context.Context, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error DeletePAT(ctx context.Context, userID, patID string) error
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error
CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error CreateGroup(ctx context.Context, group *types.Group) error
UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error UpdateGroup(ctx context.Context, group *types.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroup(ctx context.Context, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error DeleteGroups(ctx context.Context, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error)
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error CreatePolicy(ctx context.Context, policy *types.Policy) error
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error SavePolicy(ctx context.Context, policy *types.Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error DeletePolicy(ctx context.Context, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error SavePostureChecks(ctx context.Context, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error DeletePostureChecks(ctx context.Context, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
@@ -130,7 +130,7 @@ type Store interface {
GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error)
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -139,30 +139,30 @@ type Store interface {
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error DeletePeer(ctx context.Context, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error SaveSetupKey(ctx context.Context, setupKey *types.SetupKey) error
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error DeleteSetupKey(ctx context.Context, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error SaveRoute(ctx context.Context, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error DeleteRoute(ctx context.Context, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error SaveNameServerGroup(ctx context.Context, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error DeleteNameServerGroup(ctx context.Context, accountID, nameServerGroupID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)
GetInstallationID() string GetInstallationID() string
@@ -184,21 +184,21 @@ type Store interface {
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error SaveNetwork(ctx context.Context, network *networkTypes.Network) error
DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error DeleteNetwork(ctx context.Context, accountID, networkID string) error
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error)
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error)
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error)
GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error)
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error SaveNetworkResource(ctx context.Context, resource *resourceTypes.NetworkResource) error
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error DeleteNetworkResource(ctx context.Context, accountID, resourceID string) error
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)

View File

@@ -11,13 +11,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/proto" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2" authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
) )
const defaultDuration = 12 * time.Hour const defaultDuration = 12 * time.Hour
@@ -39,13 +39,14 @@ type TimeBasedAuthSecretsManager struct {
relayHmacToken *authv2.Generator relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager updateManager *PeersUpdateManager
settingsManager settings.Manager settingsManager settings.Manager
groupsManager groups.Manager
turnCancelMap map[string]chan struct{} turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{} relayCancelMap map[string]chan struct{}
} }
type Token auth.Token type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager { func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{ mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager, updateManager: updateManager,
turnCfg: turnCfg, turnCfg: turnCfg,
@@ -53,6 +54,7 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
turnCancelMap: make(map[string]chan struct{}), turnCancelMap: make(map[string]chan struct{}),
relayCancelMap: make(map[string]chan struct{}), relayCancelMap: make(map[string]chan struct{}),
settingsManager: settingsManager, settingsManager: settingsManager,
groupsManager: groupsManager,
} }
if turnCfg != nil { if turnCfg != nil {
@@ -258,6 +260,11 @@ func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, p
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err) log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
} }
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings) peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
}
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
update.NetbirdConfig = extendedConfig update.NetbirdConfig = extendedConfig
} }

View File

@@ -13,9 +13,10 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -40,13 +41,14 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager) }, rc, settingsMockManager, groupsManager)
turnCredentials, err := tested.GenerateTurnToken() turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err) require.NoError(t, err)
@@ -91,13 +93,14 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager) }, rc, settingsMockManager, groupsManager)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -193,13 +196,14 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager) }, rc, settingsMockManager, groupsManager)
tested.SetupRefresh(context.Background(), "someAccountID", peer) tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok { if _, ok := tested.turnCancelMap[peer]; !ok {

View File

@@ -2,6 +2,7 @@ package types
import ( import (
"net/netip" "net/netip"
"slices"
"time" "time"
) )
@@ -87,21 +88,21 @@ type ExtraSettings struct {
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
IntegratedValidatorGroups []string `gorm:"serializer:json"` IntegratedValidatorGroups []string `gorm:"serializer:json"`
FlowEnabled bool `gorm:"-"` FlowEnabled bool `gorm:"-"`
FlowPacketCounterEnabled bool `gorm:"-"` FlowGroups []string `gorm:"-"`
FlowENCollectionEnabled bool `gorm:"-"` FlowPacketCounterEnabled bool `gorm:"-"`
FlowDnsCollectionEnabled bool `gorm:"-"` FlowENCollectionEnabled bool `gorm:"-"`
FlowDnsCollectionEnabled bool `gorm:"-"`
} }
// Copy copies the ExtraSettings struct // Copy copies the ExtraSettings struct
func (e *ExtraSettings) Copy() *ExtraSettings { func (e *ExtraSettings) Copy() *ExtraSettings {
var cpGroup []string
return &ExtraSettings{ return &ExtraSettings{
PeerApprovalEnabled: e.PeerApprovalEnabled, PeerApprovalEnabled: e.PeerApprovalEnabled,
IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...), IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups),
IntegratedValidator: e.IntegratedValidator, IntegratedValidator: e.IntegratedValidator,
FlowEnabled: e.FlowEnabled, FlowEnabled: e.FlowEnabled,
FlowGroups: slices.Clone(e.FlowGroups),
FlowPacketCounterEnabled: e.FlowPacketCounterEnabled, FlowPacketCounterEnabled: e.FlowPacketCounterEnabled,
FlowENCollectionEnabled: e.FlowENCollectionEnabled, FlowENCollectionEnabled: e.FlowENCollectionEnabled,
FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled, FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled,

View File

@@ -17,11 +17,11 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/status"
) )
// createServiceUser creates a new service user under the given account. // createServiceUser creates a new service user under the given account.
@@ -46,7 +46,7 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
newUser.AccountID = accountID newUser.AccountID = accountID
log.WithContext(ctx).Debugf("New User: %v", newUser) log.WithContext(ctx).Debugf("New User: %v", newUser)
if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { if err = am.Store.SaveUser(ctx, newUser); err != nil {
return nil, err return nil, err
} }
@@ -95,14 +95,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inviterID := userID inviterID := userID
if initiatorUser.IsServiceUser { if initiatorUser.IsServiceUser {
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID) createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -124,7 +124,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
} }
if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { if err = am.Store.SaveUser(ctx, newUser); err != nil {
return nil, err return nil, err
} }
@@ -178,13 +178,13 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
} }
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
} }
// GetUser looks up a user by provided nbContext.UserAuths. // GetUser looks up a user by provided nbContext.UserAuths.
// Expects account to have been created already. // Expects account to have been created already.
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -209,11 +209,11 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
// ListUsers returns lists of all users under the account. // ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name. // It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
} }
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error { func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error {
if err := am.Store.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUser.Id); err != nil { if err := am.Store.DeleteUser(ctx, accountID, targetUser.Id); err != nil {
return err return err
} }
meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt}
@@ -230,7 +230,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -243,7 +243,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -347,12 +347,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -367,7 +367,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
} }
if err = am.Store.SavePAT(ctx, store.LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil { if err = am.Store.SavePAT(ctx, &pat.PersonalAccessToken); err != nil {
return nil, err return nil, err
} }
@@ -390,12 +390,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -404,12 +404,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewAdminPermissionError() return status.NewAdminPermissionError()
} }
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
if err != nil { if err != nil {
return err return err
} }
if err = am.Store.DeletePAT(ctx, store.LockingStrengthUpdate, targetUserID, tokenID); err != nil { if err = am.Store.DeletePAT(ctx, targetUserID, tokenID); err != nil {
return err return err
} }
@@ -429,12 +429,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewAdminPermissionError() return nil, status.NewAdminPermissionError()
} }
return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) return am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
} }
// GetAllPATs returns all PATs for a user // GetAllPATs returns all PATs for a user
@@ -456,12 +456,12 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -470,7 +470,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewAdminPermissionError() return nil, status.NewAdminPermissionError()
} }
return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID) return am.Store.GetUserPATs(ctx, store.LockingStrengthNone, targetUserID)
} }
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
@@ -511,7 +511,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
if !allowed { if !allowed {
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -521,7 +521,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var addUserEvents []func() var addUserEvents []func()
var usersToSave = make([]*types.User, 0, len(updates)) var usersToSave = make([]*types.User, 0, len(updates))
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -533,7 +533,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var initiatorUser *types.User var initiatorUser *types.User
if initiatorUserID != activity.SystemInitiator { if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -560,7 +560,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
updateAccountPeers = true updateAccountPeers = true
} }
} }
return transaction.SaveUsers(ctx, store.LockingStrengthUpdate, usersToSave) return transaction.SaveUsers(ctx, usersToSave)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -593,7 +593,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
} }
if settings.GroupsPropagationEnabled && updateAccountPeers { if settings.GroupsPropagationEnabled && updateAccountPeers {
if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
return nil, fmt.Errorf("failed to increment network serial: %w", err) return nil, fmt.Errorf("failed to increment network serial: %w", err)
} }
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
@@ -700,7 +700,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) { func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
if err != nil { if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
if !addIfNotExists { if !addIfNotExists {
@@ -724,7 +724,7 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi
newInitiatorUser := initiatorUser.Copy() newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = types.UserRoleAdmin newInitiatorUser.Role = types.UserRoleAdmin
if err := transaction.SaveUser(ctx, store.LockingStrengthUpdate, newInitiatorUser); err != nil { if err := transaction.SaveUser(ctx, newInitiatorUser); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
@@ -835,7 +835,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
var user *types.User var user *types.User
if initiatorUserID != activity.SystemInitiator { if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err) return nil, fmt.Errorf("failed to get user: %w", err)
} }
@@ -845,7 +845,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{} accountUsers := []*types.User{}
switch { switch {
case allowed: case allowed:
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -939,7 +939,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
// expireAndUpdatePeers expires all peers of the given user and updates them in the account // expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
log.WithContext(ctx).Debugf("Expiring %d peers for account %s", len(peers), accountID) log.WithContext(ctx).Debugf("Expiring %d peers for account %s", len(peers), accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -956,7 +956,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
peerIDs = append(peerIDs, peer.ID) peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true) peer.MarkLoginExpired(true)
if err := am.Store.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { if err := am.Store.SavePeerStatus(ctx, accountID, peer.ID, *peer.Status); err != nil {
return err return err
} }
am.StoreEvent( am.StoreEvent(
@@ -1009,7 +1009,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -1023,7 +1023,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue continue
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
allErrors = errors.Join(allErrors, err) allErrors = errors.Join(allErrors, err)
continue continue
@@ -1087,12 +1087,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var err error var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID) targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, targetUserInfo.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user to delete: %w", err) return fmt.Errorf("failed to get user to delete: %w", err)
} }
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID) userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user peers: %w", err) return fmt.Errorf("failed to get user peers: %w", err)
} }
@@ -1105,7 +1105,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
} }
} }
if err = transaction.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUserInfo.ID); err != nil { if err = transaction.DeleteUser(ctx, accountID, targetUserInfo.ID); err != nil {
return fmt.Errorf("failed to delete user: %s %w", targetUserInfo.ID, err) return fmt.Errorf("failed to delete user: %s %w", targetUserInfo.ID, err)
} }
@@ -1126,7 +1126,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
// GetOwnerInfo retrieves the owner information for a given account ID. // GetOwnerInfo retrieves the owner information for a given account ID.
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) { func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID) owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1176,7 +1176,7 @@ func validateUserInvite(invite *types.UserInfo) error {
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
accountID, userID := userAuth.AccountId, userAuth.UserId accountID, userID := userAuth.AccountId, userAuth.UserId
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1193,7 +1193,7 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -15,9 +15,9 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -88,7 +88,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID) assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID) user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthNone, tokenID)
if err != nil { if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err) t.Fatalf("Error when getting user by token ID: %s", err)
} }
@@ -1521,7 +1521,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true) _, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
assert.Error(t, err, "update user to another account should fail") assert.Error(t, err, "update user to another account should fail")
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId) user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetId)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, account1.Users[targetId].Id, user.Id) assert.Equal(t, account1.Users[targetId].Id, user.Id)
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID) assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)

View File

@@ -26,7 +26,7 @@ func NewManager(store store.Store) Manager {
} }
func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) { func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
} }
func NewManagerMock() Manager { func NewManagerMock() Manager {

Some files were not shown because too many files have changed in this diff Show More