mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-14 20:59:54 +00:00
Merge branch 'main' into fix/ice-handshake
This commit is contained in:
12
.github/pull_request_template.md
vendored
12
.github/pull_request_template.md
vendored
@@ -12,6 +12,16 @@
|
|||||||
- [ ] Is a feature enhancement
|
- [ ] Is a feature enhancement
|
||||||
- [ ] It is a refactor
|
- [ ] It is a refactor
|
||||||
- [ ] Created tests that fail without the change (if possible)
|
- [ ] Created tests that fail without the change (if possible)
|
||||||
- [ ] Extended the README / documentation, if necessary
|
|
||||||
|
|
||||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
Select exactly one:
|
||||||
|
|
||||||
|
- [ ] I added/updated documentation for this change
|
||||||
|
- [ ] Documentation is **not needed** for this change (explain why)
|
||||||
|
|
||||||
|
### Docs PR URL (required if "docs added" is checked)
|
||||||
|
Paste the PR link from https://github.com/netbirdio/docs here:
|
||||||
|
|
||||||
|
https://github.com/netbirdio/docs/pull/__
|
||||||
|
|||||||
94
.github/workflows/docs-ack.yml
vendored
Normal file
94
.github/workflows/docs-ack.yml
vendored
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
name: Docs Acknowledgement
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, edited, synchronize]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
docs-ack:
|
||||||
|
name: Require docs PR URL or explicit "not needed"
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Read PR body
|
||||||
|
id: body
|
||||||
|
run: |
|
||||||
|
BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH")
|
||||||
|
echo "body<<EOF" >> $GITHUB_OUTPUT
|
||||||
|
echo "$BODY" >> $GITHUB_OUTPUT
|
||||||
|
echo "EOF" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Validate checkbox selection
|
||||||
|
id: validate
|
||||||
|
run: |
|
||||||
|
body='${{ steps.body.outputs.body }}'
|
||||||
|
|
||||||
|
added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ')
|
||||||
|
noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ')
|
||||||
|
|
||||||
|
if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then
|
||||||
|
echo "::error::Choose exactly one: either 'docs added' OR 'not needed'."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$added_checked" -eq 0 ] && [ "$noneed_checked" -eq 0 ]; then
|
||||||
|
echo "::error::You must check exactly one docs option in the PR template."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$added_checked" -eq 1 ]; then
|
||||||
|
echo "mode=added" >> $GITHUB_OUTPUT
|
||||||
|
else
|
||||||
|
echo "mode=noneed" >> $GITHUB_OUTPUT
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Extract docs PR URL (when 'docs added')
|
||||||
|
if: steps.validate.outputs.mode == 'added'
|
||||||
|
id: extract
|
||||||
|
run: |
|
||||||
|
body='${{ steps.body.outputs.body }}'
|
||||||
|
|
||||||
|
# Strictly require HTTPS and that it's a PR in netbirdio/docs
|
||||||
|
# Examples accepted:
|
||||||
|
# https://github.com/netbirdio/docs/pull/1234
|
||||||
|
url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)
|
||||||
|
|
||||||
|
if [ -z "$url" ]; then
|
||||||
|
echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')
|
||||||
|
echo "url=$url" >> $GITHUB_OUTPUT
|
||||||
|
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Verify docs PR exists (and is open or merged)
|
||||||
|
if: steps.validate.outputs.mode == 'added'
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
id: verify
|
||||||
|
with:
|
||||||
|
pr_number: ${{ steps.extract.outputs.pr_number }}
|
||||||
|
script: |
|
||||||
|
const prNumber = parseInt(core.getInput('pr_number'), 10);
|
||||||
|
const { data } = await github.rest.pulls.get({
|
||||||
|
owner: 'netbirdio',
|
||||||
|
repo: 'docs',
|
||||||
|
pull_number: prNumber
|
||||||
|
});
|
||||||
|
|
||||||
|
// Allow open or merged PRs
|
||||||
|
const ok = data.state === 'open' || data.merged === true;
|
||||||
|
core.setOutput('state', data.state);
|
||||||
|
core.setOutput('merged', String(!!data.merged));
|
||||||
|
if (!ok) {
|
||||||
|
core.setFailed(`Docs PR #${prNumber} exists but is neither open nor merged (state=${data.state}, merged=${data.merged}).`);
|
||||||
|
}
|
||||||
|
result-encoding: string
|
||||||
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: All good
|
||||||
|
run: echo "Documentation requirement satisfied ✅"
|
||||||
18
.github/workflows/forum.yml
vendored
Normal file
18
.github/workflows/forum.yml
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
name: Post release topic on Discourse
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [published]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
post:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: roots/discourse-topic-github-release-action@main
|
||||||
|
with:
|
||||||
|
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||||
|
discourse-base-url: https://forum.netbird.io
|
||||||
|
discourse-author-username: NetBird
|
||||||
|
discourse-category: 17
|
||||||
|
discourse-tags:
|
||||||
|
releases
|
||||||
28
.github/workflows/release.yml
vendored
28
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.21"
|
SIGN_PIPE_VER: "v0.0.22"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -79,6 +79,8 @@ jobs:
|
|||||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||||
- name: Generate windows syso amd64
|
- name: Generate windows syso amd64
|
||||||
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||||
|
- name: Generate windows syso arm64
|
||||||
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
@@ -154,10 +156,20 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||||
|
|
||||||
|
- name: Install LLVM-MinGW for ARM64 cross-compilation
|
||||||
|
run: |
|
||||||
|
cd /tmp
|
||||||
|
wget -q https://github.com/mstorsjo/llvm-mingw/releases/download/20250709/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
|
||||||
|
echo "60cafae6474c7411174cff1d4ba21a8e46cadbaeb05a1bace306add301628337 llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz" | sha256sum -c
|
||||||
|
tar -xf llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
|
||||||
|
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
|
||||||
- name: Install goversioninfo
|
- name: Install goversioninfo
|
||||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||||
- name: Generate windows syso amd64
|
- name: Generate windows syso amd64
|
||||||
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||||
|
- name: Generate windows syso arm64
|
||||||
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||||
|
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
@@ -231,17 +243,3 @@ jobs:
|
|||||||
ref: ${{ env.SIGN_PIPE_VER }}
|
ref: ${{ env.SIGN_PIPE_VER }}
|
||||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||||
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||||
|
|
||||||
post_on_forum:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
continue-on-error: true
|
|
||||||
needs: [trigger_signer]
|
|
||||||
steps:
|
|
||||||
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
|
|
||||||
with:
|
|
||||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
|
||||||
discourse-base-url: https://forum.netbird.io
|
|
||||||
discourse-author-username: NetBird
|
|
||||||
discourse-category: 17
|
|
||||||
discourse-tags:
|
|
||||||
releases
|
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ builds:
|
|||||||
- arm64
|
- arm64
|
||||||
- 386
|
- 386
|
||||||
ignore:
|
ignore:
|
||||||
- goos: windows
|
|
||||||
goarch: arm64
|
|
||||||
- goos: windows
|
- goos: windows
|
||||||
goarch: arm
|
goarch: arm
|
||||||
- goos: windows
|
- goos: windows
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird-ui-windows
|
- id: netbird-ui-windows-amd64
|
||||||
dir: client/ui
|
dir: client/ui
|
||||||
binary: netbird-ui
|
binary: netbird-ui
|
||||||
env:
|
env:
|
||||||
@@ -30,6 +30,22 @@ builds:
|
|||||||
- -H windowsgui
|
- -H windowsgui
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-ui-windows-arm64
|
||||||
|
dir: client/ui
|
||||||
|
binary: netbird-ui
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=1
|
||||||
|
- CC=aarch64-w64-mingw32-clang
|
||||||
|
- CXX=aarch64-w64-mingw32-clang++
|
||||||
|
goos:
|
||||||
|
- windows
|
||||||
|
goarch:
|
||||||
|
- arm64
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
- -H windowsgui
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
archives:
|
archives:
|
||||||
- id: linux-arch
|
- id: linux-arch
|
||||||
name_template: "{{ .ProjectName }}-linux_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
name_template: "{{ .ProjectName }}-linux_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||||
@@ -38,7 +54,8 @@ archives:
|
|||||||
- id: windows-arch
|
- id: windows-arch
|
||||||
name_template: "{{ .ProjectName }}-windows_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
name_template: "{{ .ProjectName }}-windows_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||||
builds:
|
builds:
|
||||||
- netbird-ui-windows
|
- netbird-ui-windows-amd64
|
||||||
|
- netbird-ui-windows-arm64
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package android
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -112,7 +113,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -138,7 +139,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -235,7 +236,7 @@ func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.OnUpdatedHostDNSServer(list.items)
|
dnsServer.OnUpdatedHostDNSServer(slices.Clone(list.items))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,23 +1,34 @@
|
|||||||
package android
|
package android
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
// DNSList is a wrapper of []string
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DNSList is a wrapper of []netip.AddrPort with default DNS port
|
||||||
type DNSList struct {
|
type DNSList struct {
|
||||||
items []string
|
items []netip.AddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new DNS address to the collection
|
// Add new DNS address to the collection, returns error if invalid
|
||||||
func (array *DNSList) Add(s string) {
|
func (array *DNSList) Add(s string) error {
|
||||||
array.items = append(array.items, s)
|
addr, err := netip.ParseAddr(s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid DNS address: %s", s)
|
||||||
|
}
|
||||||
|
addrPort := netip.AddrPortFrom(addr.Unmap(), dns.DefaultPort)
|
||||||
|
array.items = append(array.items, addrPort)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get return an element of the collection
|
// Get return an element of the collection as string
|
||||||
func (array *DNSList) Get(i int) (string, error) {
|
func (array *DNSList) Get(i int) (string, error) {
|
||||||
if i >= len(array.items) || i < 0 {
|
if i >= len(array.items) || i < 0 {
|
||||||
return "", fmt.Errorf("out of range")
|
return "", fmt.Errorf("out of range")
|
||||||
}
|
}
|
||||||
return array.items[i], nil
|
return array.items[i].Addr().String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size return with the size of the collection
|
// Size return with the size of the collection
|
||||||
|
|||||||
@@ -3,20 +3,30 @@ package android
|
|||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
func TestDNSList_Get(t *testing.T) {
|
func TestDNSList_Get(t *testing.T) {
|
||||||
l := DNSList{
|
l := DNSList{}
|
||||||
items: make([]string, 1),
|
|
||||||
|
// Add a valid DNS address
|
||||||
|
err := l.Add("8.8.8.8")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := l.Get(0)
|
// Test getting valid index
|
||||||
|
addr, err := l.Get(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("invalid error: %s", err)
|
t.Errorf("invalid error: %s", err)
|
||||||
}
|
}
|
||||||
|
if addr != "8.8.8.8" {
|
||||||
|
t.Errorf("expected 8.8.8.8, got %s", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test negative index
|
||||||
_, err = l.Get(-1)
|
_, err = l.Get(-1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error but got nil")
|
t.Errorf("expected error but got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test out of bounds index
|
||||||
_, err = l.Get(1)
|
_, err = l.Get(1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error but got nil")
|
t.Errorf("expected error but got nil")
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ var (
|
|||||||
var debugCmd = &cobra.Command{
|
var debugCmd = &cobra.Command{
|
||||||
Use: "debug",
|
Use: "debug",
|
||||||
Short: "Debugging commands",
|
Short: "Debugging commands",
|
||||||
Long: "Provides commands for debugging and logging control within the Netbird daemon.",
|
Long: "Provides commands for debugging and logging control within the NetBird daemon.",
|
||||||
}
|
}
|
||||||
|
|
||||||
var debugBundleCmd = &cobra.Command{
|
var debugBundleCmd = &cobra.Command{
|
||||||
@@ -46,8 +46,8 @@ var debugBundleCmd = &cobra.Command{
|
|||||||
|
|
||||||
var logCmd = &cobra.Command{
|
var logCmd = &cobra.Command{
|
||||||
Use: "log",
|
Use: "log",
|
||||||
Short: "Manage logging for the Netbird daemon",
|
Short: "Manage logging for the NetBird daemon",
|
||||||
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`,
|
Long: `Commands to manage logging settings for the NetBird daemon, including ICE, gRPC, and general log levels.`,
|
||||||
}
|
}
|
||||||
|
|
||||||
var logLevelCmd = &cobra.Command{
|
var logLevelCmd = &cobra.Command{
|
||||||
@@ -184,7 +184,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
cmd.Println("Netbird up")
|
cmd.Println("netbird up")
|
||||||
time.Sleep(time.Second * 10)
|
time.Sleep(time.Second * 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,7 +202,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
cmd.Println("Netbird down")
|
cmd.Println("netbird down")
|
||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
@@ -216,11 +216,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
cmd.Println("Netbird up")
|
cmd.Println("netbird up")
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||||
|
|
||||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||||
@@ -230,7 +230,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
cmd.Println("Creating debug bundle...")
|
cmd.Println("Creating debug bundle...")
|
||||||
|
|
||||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||||
request := &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
@@ -250,7 +250,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
cmd.Println("Netbird down")
|
cmd.Println("netbird down")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !initialLevelTrace {
|
if !initialLevelTrace {
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ func init() {
|
|||||||
|
|
||||||
var loginCmd = &cobra.Command{
|
var loginCmd = &cobra.Command{
|
||||||
Use: "login",
|
Use: "login",
|
||||||
Short: "login to the Netbird Management Service (first run)",
|
Short: "login to the NetBird Management Service (first run)",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
if err := setEnvAndFlags(cmd); err != nil {
|
if err := setEnvAndFlags(cmd); err != nil {
|
||||||
return fmt.Errorf("set env and flags: %v", err)
|
return fmt.Errorf("set env and flags: %v", err)
|
||||||
|
|||||||
@@ -12,14 +12,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var logoutCmd = &cobra.Command{
|
var logoutCmd = &cobra.Command{
|
||||||
Use: "logout",
|
Use: "deregister",
|
||||||
Short: "logout from the Netbird Management Service and delete peer",
|
Aliases: []string{"logout"},
|
||||||
|
Short: "deregister from the NetBird Management Service and delete peer",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
ctx, cancel := context.WithTimeout(cmd.Context(), time.Second*15)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
@@ -44,10 +45,10 @@ var logoutCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, err := daemonClient.Logout(ctx, req); err != nil {
|
if _, err := daemonClient.Logout(ctx, req); err != nil {
|
||||||
return fmt.Errorf("logout: %v", err)
|
return fmt.Errorf("deregister: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Println("Logged out successfully")
|
cmd.Println("Deregistered successfully")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,14 +16,14 @@ import (
|
|||||||
|
|
||||||
var profileCmd = &cobra.Command{
|
var profileCmd = &cobra.Command{
|
||||||
Use: "profile",
|
Use: "profile",
|
||||||
Short: "manage Netbird profiles",
|
Short: "manage NetBird profiles",
|
||||||
Long: `Manage Netbird profiles, allowing you to list, switch, and remove profiles.`,
|
Long: `Manage NetBird profiles, allowing you to list, switch, and remove profiles.`,
|
||||||
}
|
}
|
||||||
|
|
||||||
var profileListCmd = &cobra.Command{
|
var profileListCmd = &cobra.Command{
|
||||||
Use: "list",
|
Use: "list",
|
||||||
Short: "list all profiles",
|
Short: "list all profiles",
|
||||||
Long: `List all available profiles in the Netbird client.`,
|
Long: `List all available profiles in the NetBird client.`,
|
||||||
Aliases: []string{"ls"},
|
Aliases: []string{"ls"},
|
||||||
RunE: listProfilesFunc,
|
RunE: listProfilesFunc,
|
||||||
}
|
}
|
||||||
@@ -31,7 +31,7 @@ var profileListCmd = &cobra.Command{
|
|||||||
var profileAddCmd = &cobra.Command{
|
var profileAddCmd = &cobra.Command{
|
||||||
Use: "add <profile_name>",
|
Use: "add <profile_name>",
|
||||||
Short: "add a new profile",
|
Short: "add a new profile",
|
||||||
Long: `Add a new profile to the Netbird client. The profile name must be unique.`,
|
Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
RunE: addProfileFunc,
|
RunE: addProfileFunc,
|
||||||
}
|
}
|
||||||
@@ -39,7 +39,7 @@ var profileAddCmd = &cobra.Command{
|
|||||||
var profileRemoveCmd = &cobra.Command{
|
var profileRemoveCmd = &cobra.Command{
|
||||||
Use: "remove <profile_name>",
|
Use: "remove <profile_name>",
|
||||||
Short: "remove a profile",
|
Short: "remove a profile",
|
||||||
Long: `Remove a profile from the Netbird client. The profile must not be active.`,
|
Long: `Remove a profile from the NetBird client. The profile must not be active.`,
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
RunE: removeProfileFunc,
|
RunE: removeProfileFunc,
|
||||||
}
|
}
|
||||||
@@ -47,7 +47,7 @@ var profileRemoveCmd = &cobra.Command{
|
|||||||
var profileSelectCmd = &cobra.Command{
|
var profileSelectCmd = &cobra.Command{
|
||||||
Use: "select <profile_name>",
|
Use: "select <profile_name>",
|
||||||
Short: "select a profile",
|
Short: "select a profile",
|
||||||
Long: `Select a profile to be the active profile in the Netbird client. The profile must exist.`,
|
Long: `Select a profile to be the active profile in the NetBird client. The profile must exist.`,
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
RunE: selectProfileFunc,
|
RunE: selectProfileFunc,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -119,12 +119,12 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
||||||
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL))
|
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL))
|
||||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL))
|
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL))
|
||||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets NetBird log level")
|
||||||
rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets Netbird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character")
|
rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets NetBird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character")
|
||||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||||
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
|
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
|
||||||
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
|
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
|
||||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")
|
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")
|
||||||
|
|||||||
@@ -50,10 +50,10 @@ func TestSetFlagsFromEnvVars(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||||
`comma separated list of external IPs to map to the Wireguard interface`)
|
`comma separated list of external IPs to map to the WireGuard interface`)
|
||||||
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
|
||||||
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
|
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
|
||||||
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
|
||||||
|
|
||||||
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
|
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
|
||||||
t.Setenv("NB_INTERFACE_NAME", "test-name")
|
t.Setenv("NB_INTERFACE_NAME", "test-name")
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
|
|
||||||
var serviceCmd = &cobra.Command{
|
var serviceCmd = &cobra.Command{
|
||||||
Use: "service",
|
Use: "service",
|
||||||
Short: "manages Netbird service",
|
Short: "manages NetBird service",
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -47,7 +47,7 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||||
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
||||||
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value`
|
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||||
|
|
||||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||||
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||||
@@ -64,7 +64,7 @@ func newSVCConfig() (*service.Config, error) {
|
|||||||
config := &service.Config{
|
config := &service.Config{
|
||||||
Name: serviceName,
|
Name: serviceName,
|
||||||
DisplayName: "Netbird",
|
DisplayName: "Netbird",
|
||||||
Description: "Netbird mesh network client",
|
Description: "NetBird mesh network client",
|
||||||
Option: make(service.KeyValue),
|
Option: make(service.KeyValue),
|
||||||
EnvVars: make(map[string]string),
|
EnvVars: make(map[string]string),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import (
|
|||||||
|
|
||||||
func (p *program) Start(svc service.Service) error {
|
func (p *program) Start(svc service.Service) error {
|
||||||
// Start should not block. Do the actual work async.
|
// Start should not block. Do the actual work async.
|
||||||
log.Info("starting Netbird service") //nolint
|
log.Info("starting NetBird service") //nolint
|
||||||
|
|
||||||
// Collect static system and platform information
|
// Collect static system and platform information
|
||||||
system.UpdateStaticInfo()
|
system.UpdateStaticInfo()
|
||||||
@@ -97,7 +97,7 @@ func (p *program) Stop(srv service.Service) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(time.Second * 2)
|
time.Sleep(time.Second * 2)
|
||||||
log.Info("stopped Netbird service") //nolint
|
log.Info("stopped NetBird service") //nolint
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel
|
|||||||
|
|
||||||
var runCmd = &cobra.Command{
|
var runCmd = &cobra.Command{
|
||||||
Use: "run",
|
Use: "run",
|
||||||
Short: "runs Netbird as service",
|
Short: "runs NetBird as service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
|
|
||||||
@@ -149,7 +149,7 @@ var runCmd = &cobra.Command{
|
|||||||
|
|
||||||
var startCmd = &cobra.Command{
|
var startCmd = &cobra.Command{
|
||||||
Use: "start",
|
Use: "start",
|
||||||
Short: "starts Netbird service",
|
Short: "starts NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||||
@@ -160,14 +160,14 @@ var startCmd = &cobra.Command{
|
|||||||
if err := s.Start(); err != nil {
|
if err := s.Start(); err != nil {
|
||||||
return fmt.Errorf("start service: %w", err)
|
return fmt.Errorf("start service: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Println("Netbird service has been started")
|
cmd.Println("NetBird service has been started")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var stopCmd = &cobra.Command{
|
var stopCmd = &cobra.Command{
|
||||||
Use: "stop",
|
Use: "stop",
|
||||||
Short: "stops Netbird service",
|
Short: "stops NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||||
@@ -178,14 +178,14 @@ var stopCmd = &cobra.Command{
|
|||||||
if err := s.Stop(); err != nil {
|
if err := s.Stop(); err != nil {
|
||||||
return fmt.Errorf("stop service: %w", err)
|
return fmt.Errorf("stop service: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Println("Netbird service has been stopped")
|
cmd.Println("NetBird service has been stopped")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var restartCmd = &cobra.Command{
|
var restartCmd = &cobra.Command{
|
||||||
Use: "restart",
|
Use: "restart",
|
||||||
Short: "restarts Netbird service",
|
Short: "restarts NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||||
@@ -196,14 +196,14 @@ var restartCmd = &cobra.Command{
|
|||||||
if err := s.Restart(); err != nil {
|
if err := s.Restart(); err != nil {
|
||||||
return fmt.Errorf("restart service: %w", err)
|
return fmt.Errorf("restart service: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Println("Netbird service has been restarted")
|
cmd.Println("NetBird service has been restarted")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var svcStatusCmd = &cobra.Command{
|
var svcStatusCmd = &cobra.Command{
|
||||||
Use: "status",
|
Use: "status",
|
||||||
Short: "shows Netbird service status",
|
Short: "shows NetBird service status",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||||
@@ -228,7 +228,7 @@ var svcStatusCmd = &cobra.Command{
|
|||||||
statusText = fmt.Sprintf("Unknown (%d)", status)
|
statusText = fmt.Sprintf("Unknown (%d)", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Printf("Netbird service status: %s\n", statusText)
|
cmd.Printf("NetBird service status: %s\n", statusText)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func createServiceConfigForInstall() (*service.Config, error) {
|
|||||||
|
|
||||||
var installCmd = &cobra.Command{
|
var installCmd = &cobra.Command{
|
||||||
Use: "install",
|
Use: "install",
|
||||||
Short: "installs Netbird service",
|
Short: "installs NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
if err := setupServiceCommand(cmd); err != nil {
|
if err := setupServiceCommand(cmd); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -122,14 +122,14 @@ var installCmd = &cobra.Command{
|
|||||||
return fmt.Errorf("install service: %w", err)
|
return fmt.Errorf("install service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Println("Netbird service has been installed")
|
cmd.Println("NetBird service has been installed")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var uninstallCmd = &cobra.Command{
|
var uninstallCmd = &cobra.Command{
|
||||||
Use: "uninstall",
|
Use: "uninstall",
|
||||||
Short: "uninstalls Netbird service from system",
|
Short: "uninstalls NetBird service from system",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
if err := setupServiceCommand(cmd); err != nil {
|
if err := setupServiceCommand(cmd); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -152,15 +152,15 @@ var uninstallCmd = &cobra.Command{
|
|||||||
return fmt.Errorf("uninstall service: %w", err)
|
return fmt.Errorf("uninstall service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Println("Netbird service has been uninstalled")
|
cmd.Println("NetBird service has been uninstalled")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var reconfigureCmd = &cobra.Command{
|
var reconfigureCmd = &cobra.Command{
|
||||||
Use: "reconfigure",
|
Use: "reconfigure",
|
||||||
Short: "reconfigures Netbird service with new settings",
|
Short: "reconfigures NetBird service with new settings",
|
||||||
Long: `Reconfigures the Netbird service with new settings without manual uninstall/install.
|
Long: `Reconfigures the NetBird service with new settings without manual uninstall/install.
|
||||||
This command will temporarily stop the service, update its configuration, and restart it if it was running.`,
|
This command will temporarily stop the service, update its configuration, and restart it if it was running.`,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
if err := setupServiceCommand(cmd); err != nil {
|
if err := setupServiceCommand(cmd); err != nil {
|
||||||
@@ -186,7 +186,7 @@ This command will temporarily stop the service, update its configuration, and re
|
|||||||
}
|
}
|
||||||
|
|
||||||
if wasRunning {
|
if wasRunning {
|
||||||
cmd.Println("Stopping Netbird service...")
|
cmd.Println("Stopping NetBird service...")
|
||||||
if err := s.Stop(); err != nil {
|
if err := s.Stop(); err != nil {
|
||||||
cmd.Printf("Warning: failed to stop service: %v\n", err)
|
cmd.Printf("Warning: failed to stop service: %v\n", err)
|
||||||
}
|
}
|
||||||
@@ -203,13 +203,13 @@ This command will temporarily stop the service, update its configuration, and re
|
|||||||
}
|
}
|
||||||
|
|
||||||
if wasRunning {
|
if wasRunning {
|
||||||
cmd.Println("Starting Netbird service...")
|
cmd.Println("Starting NetBird service...")
|
||||||
if err := s.Start(); err != nil {
|
if err := s.Start(); err != nil {
|
||||||
return fmt.Errorf("start service after reconfigure: %w", err)
|
return fmt.Errorf("start service after reconfigure: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Println("Netbird service has been reconfigured and started")
|
cmd.Println("NetBird service has been reconfigured and started")
|
||||||
} else {
|
} else {
|
||||||
cmd.Println("Netbird service has been reconfigured")
|
cmd.Println("NetBird service has been reconfigured")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -59,8 +59,8 @@ var sshCmd = &cobra.Command{
|
|||||||
|
|
||||||
ctx := internal.CtxInitState(cmd.Context())
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
pm := profilemanager.NewProfileManager()
|
sm := profilemanager.NewServiceManager(configPath)
|
||||||
activeProf, err := pm.GetActiveProfile()
|
activeProf, err := sm.GetActiveProfileState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get active profile: %v", err)
|
return fmt.Errorf("get active profile: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ var (
|
|||||||
var stateCmd = &cobra.Command{
|
var stateCmd = &cobra.Command{
|
||||||
Use: "state",
|
Use: "state",
|
||||||
Short: "Manage daemon state",
|
Short: "Manage daemon state",
|
||||||
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
|
Long: "Provides commands for managing and inspecting the NetBird daemon state.",
|
||||||
}
|
}
|
||||||
|
|
||||||
var stateListCmd = &cobra.Command{
|
var stateListCmd = &cobra.Command{
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
@@ -97,6 +98,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
|||||||
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
settingsMockManager.EXPECT().
|
settingsMockManager.EXPECT().
|
||||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
@@ -108,7 +110,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|||||||
@@ -53,15 +53,15 @@ var (
|
|||||||
|
|
||||||
upCmd = &cobra.Command{
|
upCmd = &cobra.Command{
|
||||||
Use: "up",
|
Use: "up",
|
||||||
Short: "install, login and start Netbird client",
|
Short: "install, login and start NetBird client",
|
||||||
RunE: upFunc,
|
RunE: upFunc,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
|
||||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
|
||||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||||
@@ -79,7 +79,7 @@ func init() {
|
|||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location. ")
|
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
versionCmd = &cobra.Command{
|
versionCmd = &cobra.Command{
|
||||||
Use: "version",
|
Use: "version",
|
||||||
Short: "prints Netbird version",
|
Short: "prints NetBird version",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
cmd.Println(version.NetbirdVersion())
|
cmd.Println(version.NetbirdVersion())
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
!define WEB_SITE "Netbird.io"
|
!define WEB_SITE "Netbird.io"
|
||||||
!define VERSION $%APPVER%
|
!define VERSION $%APPVER%
|
||||||
!define COPYRIGHT "Netbird Authors, 2022"
|
!define COPYRIGHT "Netbird Authors, 2022"
|
||||||
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network"
|
!define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
|
||||||
!define INSTALLER_NAME "netbird-installer.exe"
|
!define INSTALLER_NAME "netbird-installer.exe"
|
||||||
!define MAIN_APP_EXE "Netbird"
|
!define MAIN_APP_EXE "Netbird"
|
||||||
!define ICON "ui\\assets\\netbird.ico"
|
!define ICON "ui\\assets\\netbird.ico"
|
||||||
@@ -59,9 +59,15 @@ ShowInstDetails Show
|
|||||||
!define MUI_UNICON "${ICON}"
|
!define MUI_UNICON "${ICON}"
|
||||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||||
!define MUI_UNWELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
!define MUI_UNWELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||||
!define MUI_FINISHPAGE_RUN
|
!ifndef ARCH
|
||||||
!define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}"
|
!define ARCH "amd64"
|
||||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
!endif
|
||||||
|
|
||||||
|
!if ${ARCH} == "amd64"
|
||||||
|
!define MUI_FINISHPAGE_RUN
|
||||||
|
!define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}"
|
||||||
|
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||||
|
!endif
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
!define MUI_ABORTWARNING
|
!define MUI_ABORTWARNING
|
||||||
@@ -213,7 +219,15 @@ Section -MainProgram
|
|||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
# SetOverwrite ifnewer
|
# SetOverwrite ifnewer
|
||||||
SetOutPath "$INSTDIR"
|
SetOutPath "$INSTDIR"
|
||||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
!ifndef ARCH
|
||||||
|
!define ARCH "amd64"
|
||||||
|
!endif
|
||||||
|
|
||||||
|
!if ${ARCH} == "arm64"
|
||||||
|
File /r "..\\dist\\netbird_windows_arm64\\"
|
||||||
|
!else
|
||||||
|
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||||
|
!endif
|
||||||
SectionEnd
|
SectionEnd
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
@@ -292,7 +306,9 @@ DetailPrint "Deleting application files..."
|
|||||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||||
Delete "$INSTDIR\wintun.dll"
|
Delete "$INSTDIR\wintun.dll"
|
||||||
|
!if ${ARCH} == "amd64"
|
||||||
Delete "$INSTDIR\opengl32.dll"
|
Delete "$INSTDIR\opengl32.dll"
|
||||||
|
!endif
|
||||||
DetailPrint "Removing application directory..."
|
DetailPrint "Removing application directory..."
|
||||||
RmDir /r "$INSTDIR"
|
RmDir /r "$INSTDIR"
|
||||||
|
|
||||||
@@ -314,8 +330,10 @@ DetailPrint "Uninstallation finished."
|
|||||||
SectionEnd
|
SectionEnd
|
||||||
|
|
||||||
|
|
||||||
|
!if ${ARCH} == "amd64"
|
||||||
Function LaunchLink
|
Function LaunchLink
|
||||||
SetShellVarContext all
|
SetShellVarContext all
|
||||||
SetOutPath $INSTDIR
|
SetOutPath $INSTDIR
|
||||||
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
|
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
|
!endif
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -70,7 +71,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
tunAdapter device.TunAdapter,
|
tunAdapter device.TunAdapter,
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsAddresses []string,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type resolvConf struct {
|
type resolvConf struct {
|
||||||
nameServers []string
|
nameServers []netip.Addr
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
others []string
|
others []string
|
||||||
}
|
}
|
||||||
@@ -36,7 +36,7 @@ func parseBackupResolvConf() (*resolvConf, error) {
|
|||||||
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
||||||
rconf := &resolvConf{
|
rconf := &resolvConf{
|
||||||
searchDomains: make([]string, 0),
|
searchDomains: make([]string, 0),
|
||||||
nameServers: make([]string, 0),
|
nameServers: make([]netip.Addr, 0),
|
||||||
others: make([]string, 0),
|
others: make([]string, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +94,11 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
|||||||
if len(splitLines) != 2 {
|
if len(splitLines) != 2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
rconf.nameServers = append(rconf.nameServers, splitLines[1])
|
if addr, err := netip.ParseAddr(splitLines[1]); err == nil {
|
||||||
|
rconf.nameServers = append(rconf.nameServers, addr.Unmap())
|
||||||
|
} else {
|
||||||
|
log.Warnf("invalid nameserver address in resolv.conf: %s, skipping", splitLines[1])
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,31 +108,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
|||||||
}
|
}
|
||||||
return rconf, nil
|
return rconf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
|
|
||||||
// and writes the file back to the original location
|
|
||||||
func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
|
|
||||||
resolvConf, err := parseResolvConfFile(filename)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse backup resolv.conf: %w", err)
|
|
||||||
}
|
|
||||||
content, err := os.ReadFile(filename)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read %s: %w", filename, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
|
|
||||||
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
|
|
||||||
|
|
||||||
stat, err := os.Stat(filename)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stat %s: %w", filename, err)
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
|
|
||||||
return fmt.Errorf("write %s: %w", filename, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,13 +3,9 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_parseResolvConf(t *testing.T) {
|
func Test_parseResolvConf(t *testing.T) {
|
||||||
@@ -99,9 +95,13 @@ options debug
|
|||||||
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
|
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok = compareLists(cfg.nameServers, testCase.expectedNS)
|
nsStrings := make([]string, len(cfg.nameServers))
|
||||||
|
for i, ns := range cfg.nameServers {
|
||||||
|
nsStrings[i] = ns.String()
|
||||||
|
}
|
||||||
|
ok = compareLists(nsStrings, testCase.expectedNS)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers)
|
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, nsStrings)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok = compareLists(cfg.others, testCase.expectedOther)
|
ok = compareLists(cfg.others, testCase.expectedOther)
|
||||||
@@ -176,87 +176,3 @@ nameserver 192.168.0.1
|
|||||||
t.Errorf("unexpected resolv.conf content: %v", cfg)
|
t.Errorf("unexpected resolv.conf content: %v", cfg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveFirstNbNameserver(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
content string
|
|
||||||
ipToRemove string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Unrelated nameservers with comments and options",
|
|
||||||
content: `# This is a comment
|
|
||||||
options rotate
|
|
||||||
nameserver 1.1.1.1
|
|
||||||
# Another comment
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
search example.com`,
|
|
||||||
ipToRemove: "9.9.9.9",
|
|
||||||
expected: `# This is a comment
|
|
||||||
options rotate
|
|
||||||
nameserver 1.1.1.1
|
|
||||||
# Another comment
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
search example.com`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "First nameserver matches",
|
|
||||||
content: `search example.com
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
# oof, a comment
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
options attempts:5`,
|
|
||||||
ipToRemove: "9.9.9.9",
|
|
||||||
expected: `search example.com
|
|
||||||
# oof, a comment
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
options attempts:5`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Target IP not the first nameserver",
|
|
||||||
// nolint:dupword
|
|
||||||
content: `# Comment about the first nameserver
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
# Comment before our target
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
options timeout:2`,
|
|
||||||
ipToRemove: "9.9.9.9",
|
|
||||||
// nolint:dupword
|
|
||||||
expected: `# Comment about the first nameserver
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
# Comment before our target
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
options timeout:2`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Only nameserver matches",
|
|
||||||
content: `options debug
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
search localdomain`,
|
|
||||||
ipToRemove: "9.9.9.9",
|
|
||||||
expected: `options debug
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
search localdomain`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
tempFile := filepath.Join(tempDir, "resolv.conf")
|
|
||||||
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
ip, err := netip.ParseAddr(tc.ipToRemove)
|
|
||||||
require.NoError(t, err, "Failed to parse IP address")
|
|
||||||
err = removeFirstNbNameserver(tempFile, ip)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
content, err := os.ReadFile(tempFile)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rCon
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if rConf.nameServers[0] != nbNameserverIP.String() {
|
if rConf.nameServers[0] != nbNameserverIP {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ type fileConfigurator struct {
|
|||||||
repair *repair
|
repair *repair
|
||||||
originalPerms os.FileMode
|
originalPerms os.FileMode
|
||||||
nbNameserverIP netip.Addr
|
nbNameserverIP netip.Addr
|
||||||
originalNameservers []string
|
originalNameservers []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFileConfigurator() (*fileConfigurator, error) {
|
func newFileConfigurator() (*fileConfigurator, error) {
|
||||||
@@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf
|
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf
|
||||||
func (f *fileConfigurator) getOriginalNameservers() []string {
|
func (f *fileConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
return f.originalNameservers
|
return f.originalNameservers
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,20 +128,14 @@ func (f *fileConfigurator) backup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) restore() error {
|
func (f *fileConfigurator) restore() error {
|
||||||
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
|
if err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error {
|
||||||
resolvConf, err := parseDefaultResolvConf()
|
resolvConf, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse current resolv.conf: %w", err)
|
return fmt.Errorf("parse current resolv.conf: %w", err)
|
||||||
@@ -152,16 +146,9 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
|
|||||||
return restoreResolvConfFile()
|
return restoreResolvConfFile()
|
||||||
}
|
}
|
||||||
|
|
||||||
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
|
|
||||||
// not a valid first nameserver -> restore
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
|
|
||||||
return restoreResolvConfFile()
|
|
||||||
}
|
|
||||||
|
|
||||||
// current address is still netbird's non-available dns address -> restore
|
// current address is still netbird's non-available dns address -> restore
|
||||||
// comparing parsed addresses only, to remove ambiguity
|
currentDNSAddress := resolvConf.nameServers[0]
|
||||||
if currentDNSAddress.String() == storedDNSAddress.String() {
|
if currentDNSAddress == storedDNSAddress {
|
||||||
return restoreResolvConfFile()
|
return restoreResolvConfFile()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -239,7 +239,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
} else if inServerAddressesArray {
|
} else if inServerAddressesArray {
|
||||||
address := strings.Split(line, " : ")[1]
|
address := strings.Split(line, " : ")[1]
|
||||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
||||||
dnsSettings.ServerIP = ip
|
dnsSettings.ServerIP = ip.Unmap()
|
||||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -250,7 +250,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// default to 53 port
|
// default to 53 port
|
||||||
dnsSettings.ServerPort = defaultPort
|
dnsSettings.ServerPort = DefaultPort
|
||||||
|
|
||||||
return dnsSettings, nil
|
return dnsSettings, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (t osManagerType) String() string {
|
|||||||
|
|
||||||
type restoreHostManager interface {
|
type restoreHostManager interface {
|
||||||
hostManager
|
hostManager
|
||||||
restoreUncleanShutdownDNS(*netip.Addr) error
|
restoreUncleanShutdownDNS(netip.Addr) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
@@ -130,8 +130,9 @@ func checkStub() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
systemdResolvedAddr := netip.AddrFrom4([4]byte{127, 0, 0, 53}) // 127.0.0.53
|
||||||
for _, ns := range rConf.nameServers {
|
for _, ns := range rConf.nameServers {
|
||||||
if ns == "127.0.0.53" {
|
if ns == systemdResolvedAddr {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,9 +64,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type registryConfigurator struct {
|
type registryConfigurator struct {
|
||||||
guid string
|
guid string
|
||||||
routingAll bool
|
routingAll bool
|
||||||
gpo bool
|
gpo bool
|
||||||
|
nrptEntryCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||||
@@ -177,7 +178,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil {
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
|
Guid: r.guid,
|
||||||
|
GPO: r.gpo,
|
||||||
|
NRPTEntryCount: r.nrptEntryCount,
|
||||||
|
}); err != nil {
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,13 +198,24 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil {
|
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("add dns match policy: %w", err)
|
return fmt.Errorf("add dns match policy: %w", err)
|
||||||
}
|
}
|
||||||
|
r.nrptEntryCount = count
|
||||||
} else {
|
} else {
|
||||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||||
return fmt.Errorf("remove dns match policies: %w", err)
|
return fmt.Errorf("remove dns match policies: %w", err)
|
||||||
}
|
}
|
||||||
|
r.nrptEntryCount = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
|
Guid: r.guid,
|
||||||
|
GPO: r.gpo,
|
||||||
|
NRPTEntryCount: r.nrptEntryCount,
|
||||||
|
}); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.updateSearchDomains(searchDomains); err != nil {
|
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||||
@@ -216,32 +232,38 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
|||||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||||
}
|
}
|
||||||
r.routingAll = true
|
r.routingAll = true
|
||||||
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
log.Infof("configured %s:%d as main DNS forwarder for this peer", ip, DefaultPort)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error {
|
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
||||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||||
if r.gpo {
|
for i, domain := range domains {
|
||||||
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil {
|
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||||
return fmt.Errorf("configure GPO DNS policy: %w", err)
|
if r.gpo {
|
||||||
|
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
singleDomain := []string{domain}
|
||||||
|
|
||||||
|
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
|
||||||
|
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.gpo {
|
||||||
if err := refreshGroupPolicy(); err != nil {
|
if err := refreshGroupPolicy(); err != nil {
|
||||||
log.Warnf("failed to refresh group policy: %v", err)
|
log.Warnf("failed to refresh group policy: %v", err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
|
||||||
return fmt.Errorf("configure local DNS policy: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
|
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
|
||||||
return nil
|
return len(domains), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
|
|
||||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
||||||
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
|
||||||
return fmt.Errorf("remove existing dns policy: %w", err)
|
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||||
@@ -374,12 +396,25 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
|||||||
|
|
||||||
func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Try to remove the base entries (for backward compatibility)
|
||||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
||||||
|
}
|
||||||
|
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
for i := 0; i < r.nrptEntryCount; i++ {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err))
|
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||||
|
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||||
|
|
||||||
|
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
||||||
|
}
|
||||||
|
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := refreshGroupPolicy(); err != nil {
|
if err := refreshGroupPolicy(); err != nil {
|
||||||
|
|||||||
@@ -1,38 +1,31 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostsDNSHolder struct {
|
type hostsDNSHolder struct {
|
||||||
unprotectedDNSList map[string]struct{}
|
unprotectedDNSList map[netip.AddrPort]struct{}
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostsDNSHolder() *hostsDNSHolder {
|
func newHostsDNSHolder() *hostsDNSHolder {
|
||||||
return &hostsDNSHolder{
|
return &hostsDNSHolder{
|
||||||
unprotectedDNSList: make(map[string]struct{}),
|
unprotectedDNSList: make(map[netip.AddrPort]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hostsDNSHolder) set(list []string) {
|
func (h *hostsDNSHolder) set(list []netip.AddrPort) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
h.unprotectedDNSList = make(map[string]struct{})
|
h.unprotectedDNSList = make(map[netip.AddrPort]struct{})
|
||||||
for _, dns := range list {
|
for _, addrPort := range list {
|
||||||
dnsAddr, err := h.normalizeAddress(dns)
|
h.unprotectedDNSList[addrPort] = struct{}{}
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
h.unprotectedDNSList[dnsAddr] = struct{}{}
|
|
||||||
}
|
}
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hostsDNSHolder) get() map[string]struct{} {
|
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
|
||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
l := h.unprotectedDNSList
|
l := h.unprotectedDNSList
|
||||||
h.mutex.RUnlock()
|
h.mutex.RUnlock()
|
||||||
@@ -40,24 +33,10 @@ func (h *hostsDNSHolder) get() map[string]struct{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//nolint:unused
|
//nolint:unused
|
||||||
func (h *hostsDNSHolder) isContain(upstream string) bool {
|
func (h *hostsDNSHolder) contains(upstream netip.AddrPort) bool {
|
||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
defer h.mutex.RUnlock()
|
defer h.mutex.RUnlock()
|
||||||
|
|
||||||
_, ok := h.unprotectedDNSList[upstream]
|
_, ok := h.unprotectedDNSList[upstream]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
|
|
||||||
a, err := netip.ParseAddr(addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if a.Is4() {
|
|
||||||
return fmt.Sprintf("%s:53", addr), nil
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("[%s]:53", addr), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func (m *MockServer) DnsIP() netip.Addr {
|
|||||||
return netip.MustParseAddr("100.10.254.255")
|
return netip.MustParseAddr("100.10.254.255")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) {
|
||||||
// TODO implement me
|
// TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
|
||||||
if err := n.restoreHostDNS(); err != nil {
|
if err := n.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via network-manager: %w", err)
|
return fmt.Errorf("restoring dns via network-manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ type resolvconf struct {
|
|||||||
implType resolvconfType
|
implType resolvconfType
|
||||||
|
|
||||||
originalSearchDomains []string
|
originalSearchDomains []string
|
||||||
originalNameServers []string
|
originalNameServers []netip.Addr
|
||||||
othersConfigs []string
|
othersConfigs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,7 +110,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) getOriginalNameservers() []string {
|
func (r *resolvconf) getOriginalNameservers() []netip.Addr {
|
||||||
return r.originalNameServers
|
return r.originalNameServers
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,7 +158,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (r *resolvconf) restoreUncleanShutdownDNS(netip.Addr) error {
|
||||||
if err := r.restoreHostDNS(); err != nil {
|
if err := r.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
|
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ type Server interface {
|
|||||||
Stop()
|
Stop()
|
||||||
DnsIP() netip.Addr
|
DnsIP() netip.Addr
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
OnUpdatedHostDNSServer(strings []string)
|
OnUpdatedHostDNSServer(addrs []netip.AddrPort)
|
||||||
SearchDomains() []string
|
SearchDomains() []string
|
||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
}
|
}
|
||||||
@@ -55,7 +55,7 @@ type nsGroupsByDomain struct {
|
|||||||
// hostManagerWithOriginalNS extends the basic hostManager interface
|
// hostManagerWithOriginalNS extends the basic hostManager interface
|
||||||
type hostManagerWithOriginalNS interface {
|
type hostManagerWithOriginalNS interface {
|
||||||
hostManager
|
hostManager
|
||||||
getOriginalNameservers() []string
|
getOriginalNameservers() []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultServer dns server object
|
// DefaultServer dns server object
|
||||||
@@ -136,7 +136,7 @@ func NewDefaultServer(
|
|||||||
func NewDefaultServerPermanentUpstream(
|
func NewDefaultServerPermanentUpstream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
wgInterface WGIface,
|
wgInterface WGIface,
|
||||||
hostsDnsList []string,
|
hostsDnsList []netip.AddrPort,
|
||||||
config nbdns.Config,
|
config nbdns.Config,
|
||||||
listener listener.NetworkChangeListener,
|
listener listener.NetworkChangeListener,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
@@ -144,6 +144,7 @@ func NewDefaultServerPermanentUpstream(
|
|||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||||
|
|
||||||
ds.hostsDNSHolder.set(hostsDnsList)
|
ds.hostsDNSHolder.set(hostsDnsList)
|
||||||
ds.permanent = true
|
ds.permanent = true
|
||||||
ds.addHostRootZone()
|
ds.addHostRootZone()
|
||||||
@@ -340,7 +341,7 @@ func (s *DefaultServer) disableDNS() error {
|
|||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) {
|
||||||
s.hostsDNSHolder.set(hostsDnsList)
|
s.hostsDNSHolder.set(hostsDnsList)
|
||||||
|
|
||||||
// Check if there's any root handler
|
// Check if there's any root handler
|
||||||
@@ -461,7 +462,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
if s.service.RuntimePort() != DefaultPort && !s.hostManager.supportCustomPort() {
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
@@ -581,14 +582,13 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, ns := range originalNameservers {
|
for _, ns := range originalNameservers {
|
||||||
if ns == config.ServerIP.String() {
|
if ns == config.ServerIP {
|
||||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
|
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ns = formatAddr(ns, defaultPort)
|
addrPort := netip.AddrPortFrom(ns, DefaultPort)
|
||||||
|
handler.upstreamServers = append(handler.upstreamServers, addrPort)
|
||||||
handler.upstreamServers = append(handler.upstreamServers, ns)
|
|
||||||
}
|
}
|
||||||
handler.deactivate = func(error) { /* always active */ }
|
handler.deactivate = func(error) { /* always active */ }
|
||||||
handler.reactivate = func() { /* always active */ }
|
handler.reactivate = func() { /* always active */ }
|
||||||
@@ -695,7 +695,13 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
|||||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
|
||||||
|
if ns.IP == s.service.RuntimeIP() {
|
||||||
|
log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(handler.upstreamServers) == 0 {
|
if len(handler.upstreamServers) == 0 {
|
||||||
@@ -770,18 +776,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
|||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNSHostPort(ns nbdns.NameServer) string {
|
|
||||||
return formatAddr(ns.IP.String(), ns.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatAddr formats a nameserver address with port, handling IPv6 addresses properly
|
|
||||||
func formatAddr(address string, port int) string {
|
|
||||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() {
|
|
||||||
return fmt.Sprintf("[%s]:%d", address, port)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s:%d", address, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
// the upstream resolver from the configuration, the second one is used to
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||||
@@ -879,10 +873,7 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
handler.upstreamServers = make([]string, 0)
|
handler.upstreamServers = maps.Keys(hostDNSServers)
|
||||||
for k := range hostDNSServers {
|
|
||||||
handler.upstreamServers = append(handler.upstreamServers, k)
|
|
||||||
}
|
|
||||||
handler.deactivate = func(error) {}
|
handler.deactivate = func(error) {}
|
||||||
handler.reactivate = func() {}
|
handler.reactivate = func() {}
|
||||||
|
|
||||||
@@ -893,9 +884,9 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
|||||||
var states []peer.NSGroupState
|
var states []peer.NSGroupState
|
||||||
|
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
var servers []string
|
var servers []netip.AddrPort
|
||||||
for _, ns := range group.NameServers {
|
for _, ns := range group.NameServers {
|
||||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
servers = append(servers, ns.AddrPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
state := peer.NSGroupState{
|
state := peer.NSGroupState{
|
||||||
@@ -927,7 +918,7 @@ func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error,
|
|||||||
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||||
var servers []string
|
var servers []string
|
||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
servers = append(servers, ns.AddrPort().String())
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -97,9 +97,9 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||||
var srvs []string
|
var srvs []netip.AddrPort
|
||||||
for _, srv := range servers {
|
for _, srv := range servers {
|
||||||
srvs = append(srvs, getNSHostPort(srv))
|
srvs = append(srvs, srv.AddrPort())
|
||||||
}
|
}
|
||||||
return &upstreamResolverBase{
|
return &upstreamResolverBase{
|
||||||
domain: domain,
|
domain: domain,
|
||||||
@@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
|
|
||||||
var dnsList []string
|
var dnsList []netip.AddrPort
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
@@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer dnsServer.Stop()
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||||
|
dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort})
|
||||||
|
|
||||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
@@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@@ -2054,55 +2057,123 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
|||||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatAddr(t *testing.T) {
|
func TestDNSLoopPrevention(t *testing.T) {
|
||||||
|
wgInterface := &mocWGIface{}
|
||||||
|
service := NewServiceViaMemory(wgInterface)
|
||||||
|
dnsServerIP := service.RuntimeIP()
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
service: service,
|
||||||
|
localResolver: local.NewResolver(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: &noopHostConfigurator{},
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
address string
|
nsGroups []*nbdns.NameServerGroup
|
||||||
port int
|
expectedHandlers int
|
||||||
expected string
|
expectedServers []netip.Addr
|
||||||
|
shouldFilterOwnIP bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "IPv4 address",
|
name: "FilterOwnDNSServerIP",
|
||||||
address: "8.8.8.8",
|
nsGroups: []*nbdns.NameServerGroup{
|
||||||
port: 53,
|
{
|
||||||
expected: "8.8.8.8:53",
|
Primary: true,
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
},
|
||||||
|
Domains: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: 1,
|
||||||
|
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
|
||||||
|
shouldFilterOwnIP: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv4 address with custom port",
|
name: "AllServersFiltered",
|
||||||
address: "1.1.1.1",
|
nsGroups: []*nbdns.NameServerGroup{
|
||||||
port: 5353,
|
{
|
||||||
expected: "1.1.1.1:5353",
|
Primary: false,
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
},
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: 0,
|
||||||
|
expectedServers: []netip.Addr{},
|
||||||
|
shouldFilterOwnIP: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 address",
|
name: "MixedServersWithOwnIP",
|
||||||
address: "fd78:94bf:7df8::1",
|
nsGroups: []*nbdns.NameServerGroup{
|
||||||
port: 53,
|
{
|
||||||
expected: "[fd78:94bf:7df8::1]:53",
|
Primary: false,
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, // duplicate
|
||||||
|
},
|
||||||
|
Domains: []string{"test.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: 1,
|
||||||
|
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
|
||||||
|
shouldFilterOwnIP: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 address with custom port",
|
name: "NoOwnIPInList",
|
||||||
address: "2001:db8::1",
|
nsGroups: []*nbdns.NameServerGroup{
|
||||||
port: 5353,
|
{
|
||||||
expected: "[2001:db8::1]:5353",
|
Primary: true,
|
||||||
},
|
NameServers: []nbdns.NameServer{
|
||||||
{
|
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
name: "IPv6 localhost",
|
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
address: "::1",
|
},
|
||||||
port: 53,
|
Domains: []string{},
|
||||||
expected: "[::1]:53",
|
},
|
||||||
},
|
},
|
||||||
{
|
expectedHandlers: 1,
|
||||||
name: "Invalid address treated as hostname",
|
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
|
||||||
address: "dns.example.com",
|
shouldFilterOwnIP: false,
|
||||||
port: 53,
|
|
||||||
expected: "dns.example.com:53",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := formatAddr(tt.address, tt.port)
|
muxUpdates, err := server.buildUpstreamHandlerUpdate(tt.nsGroups)
|
||||||
assert.Equal(t, tt.expected, result)
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, muxUpdates, tt.expectedHandlers)
|
||||||
|
|
||||||
|
if tt.expectedHandlers > 0 {
|
||||||
|
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||||
|
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
|
||||||
|
|
||||||
|
if tt.shouldFilterOwnIP {
|
||||||
|
for _, upstream := range handler.upstreamServers {
|
||||||
|
assert.NotEqual(t, dnsServerIP, upstream.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range tt.expectedServers {
|
||||||
|
found := false
|
||||||
|
for _, upstream := range handler.upstreamServers {
|
||||||
|
if upstream.Addr() == expected {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, found, "Expected server %s not found", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultPort = 53
|
DefaultPort = 53
|
||||||
)
|
)
|
||||||
|
|
||||||
type service interface {
|
type service interface {
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ func (s *serviceViaListener) RuntimePort() int {
|
|||||||
defer s.listenerFlagLock.Unlock()
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
if s.ebpfService != nil {
|
if s.ebpfService != nil {
|
||||||
return defaultPort
|
return DefaultPort
|
||||||
} else {
|
} else {
|
||||||
return int(s.listenPort)
|
return int(s.listenPort)
|
||||||
}
|
}
|
||||||
@@ -148,9 +148,9 @@ func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
|
|||||||
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, ok := s.testFreePort(defaultPort)
|
ip, ok := s.testFreePort(DefaultPort)
|
||||||
if ok {
|
if ok {
|
||||||
return ip, defaultPort, nil
|
return ip, DefaultPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ebpfSrv, port, ok := s.tryToUseeBPF()
|
ebpfSrv, port, ok := s.tryToUseeBPF()
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
|||||||
dnsMux: dns.NewServeMux(),
|
dnsMux: dns.NewServeMux(),
|
||||||
|
|
||||||
runtimeIP: lastIP,
|
runtimeIP: lastIP,
|
||||||
runtimePort: defaultPort,
|
runtimePort: DefaultPort,
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -235,7 +235,7 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
|
||||||
if err := s.restoreHostDNS(); err != nil {
|
if err := s.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via systemd: %w", err)
|
return fmt.Errorf("restoring dns via systemd: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
return fmt.Errorf("create previous host manager: %w", err)
|
return fmt.Errorf("create previous host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(s.DNSAddress); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
Guid string
|
Guid string
|
||||||
GPO bool
|
GPO bool
|
||||||
|
NRPTEntryCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@@ -15,8 +16,9 @@ func (s *ShutdownState) Name() string {
|
|||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
manager := ®istryConfigurator{
|
manager := ®istryConfigurator{
|
||||||
guid: s.Guid,
|
guid: s.Guid,
|
||||||
gpo: s.GPO,
|
gpo: s.GPO,
|
||||||
|
nrptEntryCount: s.NRPTEntryCount,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
upstreamClient upstreamClient
|
upstreamClient upstreamClient
|
||||||
upstreamServers []string
|
upstreamServers []netip.AddrPort
|
||||||
domain string
|
domain string
|
||||||
disabled bool
|
disabled bool
|
||||||
failsCount atomic.Int32
|
failsCount atomic.Int32
|
||||||
@@ -79,17 +80,20 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
|||||||
|
|
||||||
// String returns a string representation of the upstream resolver
|
// String returns a string representation of the upstream resolver
|
||||||
func (u *upstreamResolverBase) String() string {
|
func (u *upstreamResolverBase) String() string {
|
||||||
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
return fmt.Sprintf("upstream %s", u.upstreamServers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
// ID returns the unique handler ID
|
||||||
func (u *upstreamResolverBase) ID() types.HandlerID {
|
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||||
servers := slices.Clone(u.upstreamServers)
|
servers := slices.Clone(u.upstreamServers)
|
||||||
slices.Sort(servers)
|
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
|
||||||
|
|
||||||
hash := sha256.New()
|
hash := sha256.New()
|
||||||
hash.Write([]byte(u.domain + ":"))
|
hash.Write([]byte(u.domain + ":"))
|
||||||
hash.Write([]byte(strings.Join(servers, ",")))
|
for _, s := range servers {
|
||||||
|
hash.Write([]byte(s.String()))
|
||||||
|
hash.Write([]byte("|"))
|
||||||
|
}
|
||||||
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,7 +134,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
func() {
|
func() {
|
||||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
|
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -197,7 +201,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
|||||||
proto.SystemEvent_DNS,
|
proto.SystemEvent_DNS,
|
||||||
"All upstream servers failed (fail count exceeded)",
|
"All upstream servers failed (fail count exceeded)",
|
||||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||||
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
|
map[string]string{"upstreams": u.upstreamServersString()},
|
||||||
// TODO add domain meta
|
// TODO add domain meta
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -258,7 +262,7 @@ func (u *upstreamResolverBase) ProbeAvailability() {
|
|||||||
proto.SystemEvent_DNS,
|
proto.SystemEvent_DNS,
|
||||||
"All upstream servers failed (probe failed)",
|
"All upstream servers failed (probe failed)",
|
||||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||||
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
|
map[string]string{"upstreams": u.upstreamServersString()},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -278,7 +282,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
operation := func() error {
|
operation := func() error {
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers))
|
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,7 +295,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff())
|
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
|
||||||
return fmt.Errorf("upstream check call error")
|
return fmt.Errorf("upstream check call error")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
|
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||||
u.failsCount.Store(0)
|
u.failsCount.Store(0)
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
u.reactivate()
|
u.reactivate()
|
||||||
@@ -331,13 +335,21 @@ func (u *upstreamResolverBase) disable(err error) {
|
|||||||
go u.waitUntilResponse()
|
go u.waitUntilResponse()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
|
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||||
|
var servers []string
|
||||||
|
for _, server := range u.upstreamServers {
|
||||||
|
servers = append(servers, server.String())
|
||||||
|
}
|
||||||
|
return strings.Join(servers, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
|
||||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||||
|
|
||||||
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -79,8 +79,8 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||||
if u.hostsDNSHolder.isContain(upstream) {
|
if addrPort, err := netip.ParseAddrPort(upstream); err == nil {
|
||||||
return true
|
return u.hostsDNSHolder.contains(addrPort)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,6 +62,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
||||||
|
} else {
|
||||||
|
upstreamIP = upstreamIP.Unmap()
|
||||||
}
|
}
|
||||||
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||||
log.Debugf("using private client to query upstream: %s", upstream)
|
log.Debugf("using private client to query upstream: %s", upstream)
|
||||||
|
|||||||
@@ -59,7 +59,14 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||||
resolver.upstreamServers = testCase.InputServers
|
// Convert test servers to netip.AddrPort
|
||||||
|
var servers []netip.AddrPort
|
||||||
|
for _, server := range testCase.InputServers {
|
||||||
|
if addrPort, err := netip.ParseAddrPort(server); err == nil {
|
||||||
|
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resolver.upstreamServers = servers
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
cancel()
|
cancel()
|
||||||
@@ -128,7 +135,8 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
|||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
}
|
}
|
||||||
resolver.upstreamServers = []string{"0.0.0.0:-1"}
|
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
|
||||||
|
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
|
||||||
resolver.failsTillDeact = 0
|
resolver.failsTillDeact = 0
|
||||||
resolver.reactivatePeriod = time.Microsecond * 100
|
resolver.reactivatePeriod = time.Microsecond * 100
|
||||||
|
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.handleDNSError(w, query, resp, domain, err)
|
f.handleDNSError(ctx, w, question, resp, domain, err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -244,20 +244,57 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
|
||||||
|
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
|
||||||
|
//
|
||||||
|
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
|
||||||
|
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
|
||||||
|
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
|
||||||
|
// only handles A/AAAA queries and returns NOTIMP for other types.
|
||||||
|
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
|
||||||
|
// Try querying for a different record type to see if the domain exists
|
||||||
|
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
||||||
|
// This helps distinguish between NXDOMAIN and NODATA.
|
||||||
|
var alternativeNetwork string
|
||||||
|
switch originalQtype {
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
alternativeNetwork = "ip4"
|
||||||
|
case dns.TypeA:
|
||||||
|
alternativeNetwork = "ip6"
|
||||||
|
default:
|
||||||
|
resp.Rcode = dns.RcodeNameError
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||||
|
// Alternative query also returned not found - domain truly doesn't exist
|
||||||
|
resp.Rcode = dns.RcodeNameError
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
||||||
|
resp.Rcode = dns.RcodeSuccess
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alternative query succeeded - domain exists but has no records of this type
|
||||||
|
resp.Rcode = dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||||
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
|
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case errors.As(err, &dnsErr):
|
case errors.As(err, &dnsErr):
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
if dnsErr.IsNotFound {
|
if dnsErr.IsNotFound {
|
||||||
// Pass through NXDOMAIN
|
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
||||||
resp.Rcode = dns.RcodeNameError
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if dnsErr.Server != "" {
|
if dnsErr.Server != "" {
|
||||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
|
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
||||||
} else {
|
} else {
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package dnsfwd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -16,8 +17,8 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_getMatchingEntries(t *testing.T) {
|
func Test_getMatchingEntries(t *testing.T) {
|
||||||
@@ -708,6 +709,131 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
assert.Len(t, matches, 3, "Should match 3 patterns")
|
assert.Len(t, matches, 3, "Should match 3 patterns")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDNSForwarder_NodataVsNxdomain tests that the forwarder correctly distinguishes
|
||||||
|
// between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of that type)
|
||||||
|
func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString("example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res", Set: set}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryType uint16
|
||||||
|
setupMocks func()
|
||||||
|
expectedCode int
|
||||||
|
expectNoAnswer bool // true if we expect NOERROR with empty answer (NODATA case)
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "domain exists but no AAAA records (NODATA)",
|
||||||
|
queryType: dns.TypeAAAA,
|
||||||
|
setupMocks: func() {
|
||||||
|
// First query for AAAA returns not found
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||||
|
// Check query for A records succeeds (domain exists)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
|
||||||
|
},
|
||||||
|
expectedCode: dns.RcodeSuccess,
|
||||||
|
expectNoAnswer: true,
|
||||||
|
description: "Should return NOERROR when domain exists but has no records of requested type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "domain exists but no A records (NODATA)",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
setupMocks: func() {
|
||||||
|
// First query for A returns not found
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||||
|
// Check query for AAAA records succeeds (domain exists)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
|
||||||
|
Return([]netip.Addr{netip.MustParseAddr("2001:db8::1")}, nil).Once()
|
||||||
|
},
|
||||||
|
expectedCode: dns.RcodeSuccess,
|
||||||
|
expectNoAnswer: true,
|
||||||
|
description: "Should return NOERROR when domain exists but has no A records",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "domain doesn't exist (NXDOMAIN)",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
setupMocks: func() {
|
||||||
|
// First query for A returns not found
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||||
|
// Check query for AAAA also returns not found (domain doesn't exist)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||||
|
},
|
||||||
|
expectedCode: dns.RcodeNameError,
|
||||||
|
expectNoAnswer: true,
|
||||||
|
description: "Should return NXDOMAIN when domain doesn't exist at all",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "domain exists with records (normal success)",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
setupMocks: func() {
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
|
||||||
|
// Expect firewall update for successful resolution
|
||||||
|
expectedPrefix := netip.PrefixFrom(netip.MustParseAddr("1.2.3.4"), 32)
|
||||||
|
mockFirewall.On("UpdateSet", set, []netip.Prefix{expectedPrefix}).Return(nil).Once()
|
||||||
|
},
|
||||||
|
expectedCode: dns.RcodeSuccess,
|
||||||
|
expectNoAnswer: false,
|
||||||
|
description: "Should return NOERROR with answer when records exist",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset mock expectations
|
||||||
|
mockResolver.ExpectedCalls = nil
|
||||||
|
mockResolver.Calls = nil
|
||||||
|
mockFirewall.ExpectedCalls = nil
|
||||||
|
mockFirewall.Calls = nil
|
||||||
|
|
||||||
|
tt.setupMocks()
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||||
|
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||||
|
if resp != nil && writtenResp == nil {
|
||||||
|
writtenResp = resp
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
|
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||||
|
|
||||||
|
if tt.expectNoAnswer {
|
||||||
|
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
||||||
|
}
|
||||||
|
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||||
// Test handling of malformed query with no questions
|
// Test handling of malformed query with no questions
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@@ -1564,13 +1565,14 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
@@ -13,7 +15,7 @@ type MobileDependency struct {
|
|||||||
TunAdapter device.TunAdapter
|
TunAdapter device.TunAdapter
|
||||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
NetworkChangeListener listener.NetworkChangeListener
|
NetworkChangeListener listener.NetworkChangeListener
|
||||||
HostDNSAddresses []string
|
HostDNSAddresses []netip.AddrPort
|
||||||
DnsReadyListener dns.ReadyListener
|
DnsReadyListener dns.ReadyListener
|
||||||
|
|
||||||
// iOS only
|
// iOS only
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ type RosenpassState struct {
|
|||||||
// whether it's enabled, and the last error message encountered during probing.
|
// whether it's enabled, and the last error message encountered during probing.
|
||||||
type NSGroupState struct {
|
type NSGroupState struct {
|
||||||
ID string
|
ID string
|
||||||
Servers []string
|
Servers []netip.AddrPort
|
||||||
Domains []string
|
Domains []string
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Error error
|
Error error
|
||||||
|
|||||||
@@ -593,17 +593,9 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfig read config file and return with Config. Errors out if it does not exist
|
||||||
func GetConfig(configPath string) (*Config, error) {
|
func GetConfig(configPath string) (*Config, error) {
|
||||||
if !fileExists(configPath) {
|
return readConfig(configPath, false)
|
||||||
return nil, fmt.Errorf("config file %s does not exist", configPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
config := &Config{}
|
|
||||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read config file %s: %w", configPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
||||||
@@ -695,6 +687,11 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
|||||||
|
|
||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
func ReadConfig(configPath string) (*Config, error) {
|
func ReadConfig(configPath string) (*Config, error) {
|
||||||
|
return readConfig(configPath, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
|
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||||
if fileExists(configPath) {
|
if fileExists(configPath) {
|
||||||
err := util.EnforcePermission(configPath)
|
err := util.EnforcePermission(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -715,6 +712,8 @@ func ReadConfig(configPath string) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
|
} else if !createIfMissing {
|
||||||
|
return nil, fmt.Errorf("config file %s does not exist", configPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||||
|
|||||||
@@ -16,19 +16,21 @@
|
|||||||
<StandardDirectory Id="ProgramFiles64Folder">
|
<StandardDirectory Id="ProgramFiles64Folder">
|
||||||
<Directory Id="NetbirdInstallDir" Name="Netbird">
|
<Directory Id="NetbirdInstallDir" Name="Netbird">
|
||||||
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
|
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
|
||||||
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird.exe" KeyPath="yes" />
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
|
||||||
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird-ui.exe">
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
|
||||||
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||||
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||||
</File>
|
</File>
|
||||||
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\wintun.dll" />
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
||||||
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\opengl32.dll" />
|
<?if $(var.ArchSuffix) = "amd64" ?>
|
||||||
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
||||||
|
<?endif ?>
|
||||||
|
|
||||||
<ServiceInstall
|
<ServiceInstall
|
||||||
Id="NetBirdService"
|
Id="NetBirdService"
|
||||||
Name="NetBird"
|
Name="NetBird"
|
||||||
DisplayName="NetBird"
|
DisplayName="NetBird"
|
||||||
Description="A WireGuard-based mesh network that connects your devices into a single private network."
|
Description="Connect your devices into a secure WireGuard-based overlay network with SSO, MFA and granular access controls."
|
||||||
Start="auto" Type="ownProcess"
|
Start="auto" Type="ownProcess"
|
||||||
ErrorControl="normal"
|
ErrorControl="normal"
|
||||||
Account="LocalSystem"
|
Account="LocalSystem"
|
||||||
|
|||||||
@@ -1197,8 +1197,14 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
|||||||
if dnsState.Error != nil {
|
if dnsState.Error != nil {
|
||||||
err = dnsState.Error.Error()
|
err = dnsState.Error.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var servers []string
|
||||||
|
for _, server := range dnsState.Servers {
|
||||||
|
servers = append(servers, server.String())
|
||||||
|
}
|
||||||
|
|
||||||
pbDnsState := &proto.NSGroupState{
|
pbDnsState := &proto.NSGroupState{
|
||||||
Servers: dnsState.Servers,
|
Servers: servers,
|
||||||
Domains: dnsState.Domains,
|
Domains: dnsState.Domains,
|
||||||
Enabled: dnsState.Enabled,
|
Enabled: dnsState.Enabled,
|
||||||
Error: err,
|
Error: err,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -302,13 +303,14 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
t.Cleanup(ctrl.Finish)
|
t.Cleanup(ctrl.Finish)
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (s *serviceClient) showProfilesUI() {
|
|||||||
widget.NewLabel(""), // profile name
|
widget.NewLabel(""), // profile name
|
||||||
layout.NewSpacer(),
|
layout.NewSpacer(),
|
||||||
widget.NewButton("Select", nil),
|
widget.NewButton("Select", nil),
|
||||||
widget.NewButton("Logout", nil),
|
widget.NewButton("Deregister", nil),
|
||||||
widget.NewButton("Remove", nil),
|
widget.NewButton("Remove", nil),
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
@@ -128,7 +128,7 @@ func (s *serviceClient) showProfilesUI() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logoutBtn.Show()
|
logoutBtn.Show()
|
||||||
logoutBtn.SetText("Logout")
|
logoutBtn.SetText("Deregister")
|
||||||
logoutBtn.OnTapped = func() {
|
logoutBtn.OnTapped = func() {
|
||||||
s.handleProfileLogout(profile.Name, refresh)
|
s.handleProfileLogout(profile.Name, refresh)
|
||||||
}
|
}
|
||||||
@@ -143,7 +143,7 @@ func (s *serviceClient) showProfilesUI() {
|
|||||||
if !confirm {
|
if !confirm {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.removeProfile(profile.Name)
|
err = s.removeProfile(profile.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to remove profile: %v", err)
|
log.Errorf("failed to remove profile: %v", err)
|
||||||
@@ -334,27 +334,27 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
|
|||||||
|
|
||||||
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
|
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
|
||||||
dialog.ShowConfirm(
|
dialog.ShowConfirm(
|
||||||
"Logout",
|
"Deregister",
|
||||||
fmt.Sprintf("Are you sure you want to logout from '%s'?", profileName),
|
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
|
||||||
func(confirm bool) {
|
func(confirm bool) {
|
||||||
if !confirm {
|
if !confirm {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to get service client: %v", err)
|
log.Errorf("failed to get service client: %v", err)
|
||||||
dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles)
|
dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
currUser, err := user.Current()
|
currUser, err := user.Current()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to get current user: %v", err)
|
log.Errorf("failed to get current user: %v", err)
|
||||||
dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles)
|
dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
username := currUser.Username
|
username := currUser.Username
|
||||||
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
|
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
|
||||||
ProfileName: &profileName,
|
ProfileName: &profileName,
|
||||||
@@ -362,16 +362,16 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
|
|||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("logout failed: %v", err)
|
log.Errorf("logout failed: %v", err)
|
||||||
dialog.ShowError(fmt.Errorf("logout failed"), s.wProfiles)
|
dialog.ShowError(fmt.Errorf("deregister failed"), s.wProfiles)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dialog.ShowInformation(
|
dialog.ShowInformation(
|
||||||
"Logged Out",
|
"Deregistered",
|
||||||
fmt.Sprintf("Successfully logged out from '%s'", profileName),
|
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
|
||||||
s.wProfiles,
|
s.wProfiles,
|
||||||
)
|
)
|
||||||
|
|
||||||
refreshCallback()
|
refreshCallback()
|
||||||
},
|
},
|
||||||
s.wProfiles,
|
s.wProfiles,
|
||||||
@@ -602,7 +602,7 @@ func (p *profileMenu) refresh() {
|
|||||||
|
|
||||||
// Add Logout menu item
|
// Add Logout menu item
|
||||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||||
logoutItem := p.profileMenuItem.AddSubMenuItem("Logout", "")
|
logoutItem := p.profileMenuItem.AddSubMenuItem("Deregister", "")
|
||||||
p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2}
|
p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -616,9 +616,9 @@ func (p *profileMenu) refresh() {
|
|||||||
}
|
}
|
||||||
if err := p.eventHandler.logout(p.ctx); err != nil {
|
if err := p.eventHandler.logout(p.ctx); err != nil {
|
||||||
log.Errorf("logout failed: %v", err)
|
log.Errorf("logout failed: %v", err)
|
||||||
p.app.SendNotification(fyne.NewNotification("Error", "Failed to logout"))
|
p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
|
||||||
} else {
|
} else {
|
||||||
p.app.SendNotification(fyne.NewNotification("Success", "Logged out successfully"))
|
p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -102,6 +102,11 @@ func (n *NameServer) IsEqual(other *NameServer) bool {
|
|||||||
other.Port == n.Port
|
other.Port == n.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddrPort returns the nameserver as a netip.AddrPort
|
||||||
|
func (n *NameServer) AddrPort() netip.AddrPort {
|
||||||
|
return netip.AddrPortFrom(n.IP, uint16(n.Port))
|
||||||
|
}
|
||||||
|
|
||||||
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
|
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
|
||||||
func ParseNameServerURL(nsURL string) (NameServer, error) {
|
func ParseNameServerURL(nsURL string) (NameServer, error) {
|
||||||
parsedURL, err := url.Parse(nsURL)
|
parsedURL, err := url.Parse(nsURL)
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -63,7 +63,7 @@ require (
|
|||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f h1:YmqNWdRbeVn1lSpkLzIiFHX2cndRuaVYyynx2ibrOtg=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e h1:S85laGfx1UP+nmRF9smP6/TY965kLWz41PbBK1TX8g0=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e/go.mod h1:Jjve0+eUjOLKL3PJtAhjfM2iJ0SxWio5elHqlV1ymP8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ import (
|
|||||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
"github.com/netbirdio/netbird/management/server/peers"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -45,7 +46,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/metrics"
|
"github.com/netbirdio/netbird/management/server/metrics"
|
||||||
@@ -220,7 +220,8 @@ var (
|
|||||||
return fmt.Errorf("build default manager: %v", err)
|
return fmt.Errorf("build default manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager)
|
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
|
||||||
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager, groupsManager)
|
||||||
|
|
||||||
trustedPeers := config.ReverseProxy.TrustedPeers
|
trustedPeers := config.ReverseProxy.TrustedPeers
|
||||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||||
@@ -277,7 +278,6 @@ var (
|
|||||||
config.GetAuthAudiences(),
|
config.GetAuthAudiences(),
|
||||||
config.HttpConfig.IdpSignKeyRefreshEnabled)
|
config.HttpConfig.IdpSignKeyRefreshEnabled)
|
||||||
|
|
||||||
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
|
|
||||||
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
|
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
|
||||||
routersManager := routers.NewManager(store, permissionsManager, accountManager)
|
routersManager := routers.NewManager(store, permissionsManager, accountManager)
|
||||||
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
|
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
|
||||||
|
|||||||
@@ -40,12 +40,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -346,12 +346,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
}
|
}
|
||||||
|
|
||||||
if updateAccountPeers || groupsUpdated {
|
if updateAccountPeers || groupsUpdated {
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.SaveAccountSettings(ctx, store.LockingStrengthUpdate, accountID, newSettings)
|
return transaction.SaveAccountSettings(ctx, accountID, newSettings)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -405,7 +405,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
|
|||||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -746,7 +746,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
|
|
||||||
// AccountExists checks if an account exists.
|
// AccountExists checks if an account exists.
|
||||||
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
|
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
|
||||||
return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
|
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
|
||||||
@@ -758,7 +758,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
|||||||
return "", status.Errorf(status.NotFound, "no valid userID provided")
|
return "", status.Errorf(status.NotFound, "no valid userID provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
|
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||||
@@ -813,7 +813,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
|||||||
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
||||||
accountIDString := fmt.Sprintf("%v", accountID)
|
accountIDString := fmt.Sprintf("%v", accountID)
|
||||||
|
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString)
|
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -867,7 +867,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
|
|||||||
|
|
||||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||||
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
|
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -897,7 +897,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
|||||||
|
|
||||||
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
|
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
|
||||||
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
|
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
|
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1048,7 +1048,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
|||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
|
|
||||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID)
|
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -1058,7 +1058,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -1145,7 +1145,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
|||||||
|
|
||||||
newUser := types.NewRegularUser(userAuth.UserId)
|
newUser := types.NewRegularUser(userAuth.UserId)
|
||||||
newUser.AccountID = domainAccountID
|
newUser.AccountID = domainAccountID
|
||||||
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
|
err := am.Store.SaveUser(ctx, newUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -1223,7 +1223,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
||||||
@@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
|||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// this is not really possible because we got an account by user ID
|
// this is not really possible because we got an account by user ID
|
||||||
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
|
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
|
||||||
@@ -1340,7 +1340,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1366,12 +1366,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
var hasChanges bool
|
var hasChanges bool
|
||||||
var user *types.User
|
var user *types.User
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error getting user: %w", err)
|
return fmt.Errorf("error getting user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error getting account groups: %w", err)
|
return fmt.Errorf("error getting account groups: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1387,7 +1387,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil {
|
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
|
||||||
return fmt.Errorf("error saving groups: %w", err)
|
return fmt.Errorf("error saving groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1395,13 +1395,13 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)
|
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)
|
||||||
|
|
||||||
user.AutoGroups = updatedAutoGroups
|
user.AutoGroups = updatedAutoGroups
|
||||||
if err = transaction.SaveUser(ctx, store.LockingStrengthUpdate, user); err != nil {
|
if err = transaction.SaveUser(ctx, user); err != nil {
|
||||||
return fmt.Errorf("error saving user: %w", err)
|
return fmt.Errorf("error saving user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Propagate changes to peers if group propagation is enabled
|
// Propagate changes to peers if group propagation is enabled
|
||||||
if settings.GroupsPropagationEnabled {
|
if settings.GroupsPropagationEnabled {
|
||||||
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
|
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error getting user peers: %w", err)
|
return fmt.Errorf("error getting user peers: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1419,7 +1419,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil {
|
||||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1437,7 +1437,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range addNewGroups {
|
for _, g := range addNewGroups {
|
||||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
|
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
||||||
} else {
|
} else {
|
||||||
@@ -1450,7 +1450,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range removeOldGroups {
|
for _, g := range removeOldGroups {
|
||||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
|
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
||||||
} else {
|
} else {
|
||||||
@@ -1511,7 +1511,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userAuth.IsChild {
|
if userAuth.IsChild {
|
||||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||||
if err != nil || !exists {
|
if err != nil || !exists {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -1535,7 +1535,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||||
if handleNotFound(err) != nil {
|
if handleNotFound(err) != nil {
|
||||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1556,7 +1556,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
|||||||
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
|
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
|
||||||
}
|
}
|
||||||
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
|
||||||
if handleNotFound(err) != nil {
|
if handleNotFound(err) != nil {
|
||||||
|
|
||||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||||
@@ -1571,7 +1571,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
|||||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||||
|
|
||||||
// check again if the domain has a primary account because of simultaneous requests
|
// check again if the domain has a primary account because of simultaneous requests
|
||||||
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
|
||||||
if handleNotFound(err) != nil {
|
if handleNotFound(err) != nil {
|
||||||
cancel()
|
cancel()
|
||||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||||
@@ -1582,7 +1582,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1592,7 +1592,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
|
|||||||
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
|
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
|
||||||
}
|
}
|
||||||
|
|
||||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||||
if handleNotFound(err) != nil {
|
if handleNotFound(err) != nil {
|
||||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1603,7 +1603,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
|
|||||||
}
|
}
|
||||||
|
|
||||||
// We checked if the domain has a primary account already
|
// We checked if the domain has a primary account already
|
||||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain)
|
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain)
|
||||||
if handleNotFound(err) != nil {
|
if handleNotFound(err) != nil {
|
||||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1751,7 +1751,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
|
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
|
||||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
|
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -1780,7 +1780,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
|||||||
if !allowed {
|
if !allowed {
|
||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||||
@@ -1870,7 +1870,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
|||||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
|
||||||
if handleNotFound(err) != nil {
|
if handleNotFound(err) != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
@@ -1890,7 +1890,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
|||||||
for range 2 {
|
for range 2 {
|
||||||
accountId := xid.New().String()
|
accountId := xid.New().String()
|
||||||
|
|
||||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
|
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId)
|
||||||
if err != nil || exists {
|
if err != nil || exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1965,7 +1965,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
|
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
|
||||||
|
|
||||||
// error is not a not found error
|
// error is not a not found error
|
||||||
if handleNotFound(err) != nil {
|
if handleNotFound(err) != nil {
|
||||||
@@ -2002,17 +2002,17 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
|||||||
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
|
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
|
||||||
// Returns true if any groups were modified, true if those updates affect peers and an error.
|
// Returns true if any groups were modified, true if those updates affect peers and an error.
|
||||||
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
|
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
|
||||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, err
|
return false, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID)
|
accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, fmt.Errorf("error getting account group peers: %w", err)
|
return false, false, fmt.Errorf("error getting account group peers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, fmt.Errorf("error getting account groups: %w", err)
|
return false, false, fmt.Errorf("error getting account groups: %w", err)
|
||||||
}
|
}
|
||||||
@@ -2025,7 +2025,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
|
|||||||
|
|
||||||
updatedGroups := []string{}
|
updatedGroups := []string{}
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
|
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, err
|
return false, false, err
|
||||||
}
|
}
|
||||||
@@ -2074,7 +2074,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
|
|||||||
|
|
||||||
account.Network.Net = newIPNet
|
account.Network.Net = newIPNet
|
||||||
|
|
||||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -2099,7 +2099,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil {
|
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||||
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
|
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2154,7 +2154,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context,
|
|||||||
return fmt.Errorf("get account: %w", err)
|
return fmt.Errorf("get account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get peer: %w", err)
|
return fmt.Errorf("get peer: %w", err)
|
||||||
}
|
}
|
||||||
@@ -2185,7 +2185,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context,
|
|||||||
func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error {
|
func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error {
|
||||||
log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP)
|
log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP)
|
||||||
|
|
||||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get account settings: %w", err)
|
return fmt.Errorf("get account settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -2195,7 +2195,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
|
|||||||
oldIP := peer.IP.String()
|
oldIP := peer.IP.String()
|
||||||
|
|
||||||
peer.IP = newIP.AsSlice()
|
peer.IP = newIP.AsSlice()
|
||||||
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer)
|
err = transaction.SavePeer(ctx, accountID, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("save peer: %w", err)
|
return fmt.Errorf("save peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -783,7 +783,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID)
|
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, exists, "expected to get existing account after creation using userid")
|
assert.True(t, exists, "expected to get existing account after creation using userid")
|
||||||
|
|
||||||
@@ -900,11 +900,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
|||||||
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
|
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
|
||||||
}
|
}
|
||||||
|
|
||||||
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1")
|
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, pats, 0)
|
assert.Len(t, pats, 0)
|
||||||
|
|
||||||
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId)
|
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, pats, 0)
|
assert.Len(t, pats, 0)
|
||||||
}
|
}
|
||||||
@@ -1786,7 +1786,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
|||||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
|
||||||
require.NoError(t, err, "unable to get account settings")
|
require.NoError(t, err, "unable to get account settings")
|
||||||
|
|
||||||
assert.NotNil(t, settings)
|
assert.NotNil(t, settings)
|
||||||
@@ -1971,7 +1971,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
|
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
|
||||||
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
|
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
|
||||||
|
|
||||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
|
||||||
require.NoError(t, err, "unable to get account settings")
|
require.NoError(t, err, "unable to get account settings")
|
||||||
|
|
||||||
assert.False(t, settings.PeerLoginExpirationEnabled)
|
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||||
@@ -2655,7 +2655,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
|
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
|
||||||
})
|
})
|
||||||
@@ -2669,7 +2669,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
|
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
|
||||||
})
|
})
|
||||||
@@ -2683,18 +2683,18 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 0)
|
assert.Len(t, user.AutoGroups, 0)
|
||||||
|
|
||||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
|
||||||
assert.NoError(t, err, "unable to get group")
|
assert.NoError(t, err, "unable to get group")
|
||||||
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
|
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
|
||||||
account.Users["user1"].AutoGroups = []string{"group1"}
|
account.Users["user1"].AutoGroups = []string{"group1"}
|
||||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"]))
|
assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
|
||||||
|
|
||||||
claims := nbcontext.UserAuth{
|
claims := nbcontext.UserAuth{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
@@ -2704,11 +2704,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 1)
|
assert.Len(t, user.AutoGroups, 1)
|
||||||
|
|
||||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
|
||||||
assert.NoError(t, err, "unable to get group")
|
assert.NoError(t, err, "unable to get group")
|
||||||
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
||||||
})
|
})
|
||||||
@@ -2722,7 +2722,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||||
})
|
})
|
||||||
@@ -2736,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||||
})
|
})
|
||||||
@@ -2750,11 +2750,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
|
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, groups, 3, "new group3 should be added")
|
assert.Len(t, groups, 3, "new group3 should be added")
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||||
})
|
})
|
||||||
@@ -2768,7 +2768,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
|
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
|
||||||
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
|
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
|
||||||
@@ -2783,7 +2783,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
|
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
|
||||||
})
|
})
|
||||||
@@ -3348,11 +3348,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
|
err = manager.Store.AddPeerToAccount(ctx, peer1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
|
err = manager.Store.AddPeerToAccount(ctx, peer2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
|
t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
|
||||||
@@ -3364,20 +3364,20 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
|
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
|
||||||
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
|
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
|
||||||
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1))
|
require.NoError(t, manager.Store.CreateGroup(ctx, group1))
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user.AutoGroups = append(user.AutoGroups, group1.ID)
|
user.AutoGroups = append(user.AutoGroups, group1.ID)
|
||||||
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
|
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||||
|
|
||||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, groupsUpdated)
|
assert.True(t, groupsUpdated)
|
||||||
assert.False(t, groupChangesAffectPeers)
|
assert.False(t, groupChangesAffectPeers)
|
||||||
|
|
||||||
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID)
|
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, group.Peers, 2)
|
assert.Len(t, group.Peers, 2)
|
||||||
assert.Contains(t, group.Peers, "peer1")
|
assert.Contains(t, group.Peers, "peer1")
|
||||||
@@ -3386,13 +3386,13 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
|
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
|
||||||
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
|
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
|
||||||
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2))
|
require.NoError(t, manager.Store.CreateGroup(ctx, group2))
|
||||||
|
|
||||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user.AutoGroups = append(user.AutoGroups, group2.ID)
|
user.AutoGroups = append(user.AutoGroups, group2.ID)
|
||||||
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
|
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||||
|
|
||||||
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
|
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
|
||||||
Name: "Group1 Policy",
|
Name: "Group1 Policy",
|
||||||
@@ -3415,7 +3415,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
|||||||
assert.True(t, groupsUpdated)
|
assert.True(t, groupsUpdated)
|
||||||
assert.True(t, groupChangesAffectPeers)
|
assert.True(t, groupChangesAffectPeers)
|
||||||
|
|
||||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
|
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
assert.Len(t, group.Peers, 2)
|
assert.Len(t, group.Peers, 2)
|
||||||
@@ -3432,18 +3432,18 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
|
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
|
||||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user.AutoGroups = []string{"group1"}
|
user.AutoGroups = []string{"group1"}
|
||||||
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
|
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||||
|
|
||||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, groupsUpdated)
|
assert.False(t, groupsUpdated)
|
||||||
assert.False(t, groupChangesAffectPeers)
|
assert.False(t, groupChangesAffectPeers)
|
||||||
|
|
||||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
|
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
assert.Len(t, group.Peers, 2)
|
assert.Len(t, group.Peers, 2)
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
|
|||||||
return userAuth, nil
|
return userAuth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return userAuth, err
|
return userAuth, err
|
||||||
}
|
}
|
||||||
@@ -94,7 +94,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
|
|||||||
|
|
||||||
// MarkPATUsed marks a personal access token as used
|
// MarkPATUsed marks a personal access token as used
|
||||||
func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||||
return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
|
return am.store.MarkPATUsed(ctx, tokenID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
|
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
|
||||||
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
|
|||||||
return nil, nil, "", "", err
|
return nil, nil, "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
|
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, "", "", err
|
return nil, nil, "", "", err
|
||||||
}
|
}
|
||||||
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
|
|||||||
var pat *types.PersonalAccessToken
|
var pat *types.PersonalAccessToken
|
||||||
|
|
||||||
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
|
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -8,14 +8,14 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||||
@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||||
@@ -113,11 +113,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
|||||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||||
eventsToStore = append(eventsToStore, events...)
|
eventsToStore = append(eventsToStore, events...)
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
|
|||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
|
|
||||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||||
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare)
|
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
|
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isEnabled() bool {
|
func isEnabled() bool {
|
||||||
@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
|
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId)
|
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
|
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// @todo consider logging
|
// @todo consider logging
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ import (
|
|||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GroupLinkError struct {
|
type GroupLinkError struct {
|
||||||
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
|||||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllGroups returns all groups in an account
|
// GetAllGroups returns all groups in an account
|
||||||
@@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
|||||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
|
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroup object of the peers
|
// CreateGroup object of the peers
|
||||||
@@ -96,11 +96,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil {
|
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,7 +147,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
|
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
|
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
|
||||||
}
|
}
|
||||||
@@ -176,11 +176,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup)
|
return transaction.UpdateGroup(ctx, newGroup)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -234,11 +234,11 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
|
return transaction.CreateGroups(ctx, accountID, groupsToSave)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -292,11 +292,11 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
|
return transaction.UpdateGroups(ctx, accountID, groupsToSave)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -320,7 +320,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
|||||||
addedPeers := make([]string, 0)
|
addedPeers := make([]string, 0)
|
||||||
removedPeers := make([]string, 0)
|
removedPeers := make([]string, 0)
|
||||||
|
|
||||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
|
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||||
if err == nil && oldGroup != nil {
|
if err == nil && oldGroup != nil {
|
||||||
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
|
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||||
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
|
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||||
@@ -332,13 +332,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
|||||||
}
|
}
|
||||||
|
|
||||||
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||||
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers)
|
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
|
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -423,11 +423,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
|||||||
deletedGroups = append(deletedGroups, group)
|
deletedGroups = append(deletedGroups, group)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete)
|
return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -454,7 +454,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -495,11 +495,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group)
|
return transaction.UpdateGroup(ctx, group)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -526,7 +526,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -567,11 +567,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group)
|
return transaction.UpdateGroup(ctx, group)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -591,7 +591,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
|
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
|
||||||
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name)
|
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||||
return err
|
return err
|
||||||
@@ -608,7 +608,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, peerID := range newGroup.Peers {
|
for _, peerID := range newGroup.Peers {
|
||||||
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||||
}
|
}
|
||||||
@@ -620,7 +620,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
|||||||
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
|
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
|
||||||
// disable a deleting integration group if the initiator is not an admin service user
|
// disable a deleting integration group if the initiator is not an admin service user
|
||||||
if group.Issued == types.GroupIssuedIntegration {
|
if group.Issued == types.GroupIssuedIntegration {
|
||||||
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.Internal, "failed to get user")
|
return status.Errorf(status.Internal, "failed to get user")
|
||||||
}
|
}
|
||||||
@@ -666,7 +666,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
|
|||||||
|
|
||||||
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||||
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
|
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
|
||||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID)
|
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.Internal, "failed to get DNS settings")
|
return status.Errorf(status.Internal, "failed to get DNS settings")
|
||||||
}
|
}
|
||||||
@@ -675,7 +675,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
|
|||||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID)
|
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.Internal, "failed to get account settings")
|
return status.Errorf(status.Internal, "failed to get account settings")
|
||||||
}
|
}
|
||||||
@@ -689,7 +689,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
|
|||||||
|
|
||||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||||
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
|
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
|
||||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||||
return false, nil
|
return false, nil
|
||||||
@@ -709,7 +709,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
|
|||||||
|
|
||||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||||
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
|
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
|
||||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||||
return false, nil
|
return false, nil
|
||||||
@@ -727,7 +727,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
|
|||||||
|
|
||||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||||
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||||
return false, nil
|
return false, nil
|
||||||
@@ -746,7 +746,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
|
|||||||
|
|
||||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||||
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
|
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
|
||||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
|
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||||
return false, nil
|
return false, nil
|
||||||
@@ -762,7 +762,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
|
|||||||
|
|
||||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||||
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
|
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
|
||||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||||
return false, nil
|
return false, nil
|
||||||
@@ -778,7 +778,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
|
|||||||
|
|
||||||
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
|
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
|
||||||
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
|
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
|
||||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
|
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
|
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
|
||||||
return false, nil
|
return false, nil
|
||||||
@@ -798,7 +798,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
|
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -826,7 +826,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
|||||||
|
|
||||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,10 +26,10 @@ import (
|
|||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -898,7 +898,7 @@ func Test_AddPeerAndAddToAll(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||||
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
err = transaction.AddPeerToAccount(context.Background(), peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||||
}
|
}
|
||||||
@@ -971,7 +971,7 @@ func Test_IncrementNetworkSerial(t *testing.T) {
|
|||||||
<-start
|
<-start
|
||||||
|
|
||||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||||
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID)
|
err = transaction.IncrementNetworkSerial(context.Background(), accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get account %s: %v", accountID, err)
|
return fmt.Errorf("failed to get account %s: %v", accountID, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
@@ -21,6 +21,7 @@ type Manager interface {
|
|||||||
AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error
|
AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error
|
||||||
AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error)
|
AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error)
|
||||||
RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error)
|
RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error)
|
||||||
|
GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
@@ -49,7 +50,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting account groups: %w", err)
|
return nil, fmt.Errorf("error getting account groups: %w", err)
|
||||||
}
|
}
|
||||||
@@ -96,13 +97,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
|
|||||||
return nil, fmt.Errorf("error adding resource to group: %w", err)
|
return nil, fmt.Errorf("error adding resource to group: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting group: %w", err)
|
return nil, fmt.Errorf("error getting group: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: at some point, this will need to become a switch statement
|
// TODO: at some point, this will need to become a switch statement
|
||||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID)
|
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting network resource: %w", err)
|
return nil, fmt.Errorf("error getting network resource: %w", err)
|
||||||
}
|
}
|
||||||
@@ -120,13 +121,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
|
|||||||
return nil, fmt.Errorf("error removing resource from group: %w", err)
|
return nil, fmt.Errorf("error removing resource from group: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting group: %w", err)
|
return nil, fmt.Errorf("error getting group: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: at some point, this will need to become a switch statement
|
// TODO: at some point, this will need to become a switch statement
|
||||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
|
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting network resource: %w", err)
|
return nil, fmt.Errorf("error getting network resource: %w", err)
|
||||||
}
|
}
|
||||||
@@ -142,6 +143,10 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa
|
|||||||
return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID)
|
return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
|
||||||
|
return m.store.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
func ToGroupsInfoMap(groups []*types.Group, idCount int) map[string][]api.GroupMinimum {
|
func ToGroupsInfoMap(groups []*types.Group, idCount int) map[string][]api.GroupMinimum {
|
||||||
groupsInfoMap := make(map[string][]api.GroupMinimum, idCount)
|
groupsInfoMap := make(map[string][]api.GroupMinimum, idCount)
|
||||||
groupsChecked := make(map[string]struct{}, len(groups)) // not sure why this is needed (left over from old implementation)
|
groupsChecked := make(map[string]struct{}, len(groups)) // not sure why this is needed (left over from old implementation)
|
||||||
@@ -202,6 +207,10 @@ func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
|
||||||
|
return []string{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewManagerMock() Manager {
|
func NewManagerMock() Manager {
|
||||||
return &mockManager{}
|
return &mockManager{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
@@ -32,9 +31,10 @@ import (
|
|||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
internalStatus "github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
internalStatus "github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GRPCServer an instance of a Management gRPC API server
|
// GRPCServer an instance of a Management gRPC API server
|
||||||
@@ -662,7 +662,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings) *proto.SyncResponse {
|
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
|
||||||
response := &proto.SyncResponse{
|
response := &proto.SyncResponse{
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||||
NetworkMap: &proto.NetworkMap{
|
NetworkMap: &proto.NetworkMap{
|
||||||
@@ -674,7 +674,7 @@ func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer
|
|||||||
}
|
}
|
||||||
|
|
||||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings)
|
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||||
response.NetbirdConfig = extendedConfig
|
response.NetbirdConfig = extendedConfig
|
||||||
|
|
||||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||||
@@ -750,7 +750,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
|||||||
return status.Errorf(codes.Internal, "error handling request")
|
return status.Errorf(codes.Internal, "error handling request")
|
||||||
}
|
}
|
||||||
|
|
||||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra)
|
peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups)
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -913,6 +918,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
|
|||||||
|
|
||||||
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||||
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
empty := &proto.Empty{}
|
empty := &proto.Empty{}
|
||||||
peerKey, err := s.parseRequest(ctx, req, empty)
|
peerKey, err := s.parseRequest(ctx, req, empty)
|
||||||
@@ -920,7 +926,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peerKey.String())
|
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String())
|
log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String())
|
||||||
// TODO: consider idempotency
|
// TODO: consider idempotency
|
||||||
@@ -944,7 +950,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
|
|||||||
|
|
||||||
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
|
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
|
||||||
|
|
||||||
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String())
|
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
|
||||||
|
|
||||||
return &proto.Empty{}, nil
|
return &proto.Empty{}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -199,6 +199,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
settings.Extra = &types.ExtraSettings{
|
settings.Extra = &types.ExtraSettings{
|
||||||
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
||||||
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
||||||
|
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
||||||
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
|||||||
apiSettings.Extra = &api.AccountExtraSettings{
|
apiSettings.Extra = &api.AccountExtraSettings{
|
||||||
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
|
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
|
||||||
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
|
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
|
||||||
|
NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
|
||||||
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,
|
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
|||||||
|
|
||||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID)
|
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -97,17 +97,17 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
|||||||
var peers []*nbpeer.Peer
|
var peers []*nbpeer.Peer
|
||||||
var settings *types.Settings
|
var settings *types.Settings
|
||||||
|
|
||||||
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,14 +22,15 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -446,6 +447,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
|||||||
Return(&types.ExtraSettings{}, nil).
|
Return(&types.ExtraSettings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
permissionsManager := permissions.NewManager(store)
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||||
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
@@ -455,7 +457,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
|||||||
return nil, nil, "", cleanup, err
|
return nil, nil, "", cleanup, err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
|
|
||||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
|
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
|
||||||
@@ -645,7 +647,7 @@ func testSyncStatusRace(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
|
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
@@ -216,7 +217,8 @@ func startServer(
|
|||||||
t.Fatalf("failed creating an account manager: %v", err)
|
t.Fatalf("failed creating an account manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
|
||||||
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := server.NewServer(
|
mgmtServer, err := server.NewServer(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
config,
|
config,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetColumnName(db *gorm.DB, column string) string {
|
func GetColumnName(db *gorm.DB, column string) string {
|
||||||
@@ -466,7 +467,7 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, value := range data {
|
for _, value := range data {
|
||||||
if err := tx.Create(
|
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(
|
||||||
mapperFunc(row["account_id"].(string), row["id"].(string), value),
|
mapperFunc(row["account_id"].(string), row["id"].(string), value),
|
||||||
).Error; err != nil {
|
).Error; err != nil {
|
||||||
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
|
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
|
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
|
||||||
@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID)
|
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameServerGroup creates and saves a new nameserver group
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
@@ -73,11 +73,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup)
|
return transaction.SaveNameServerGroup(ctx, newNSGroup)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
|||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID)
|
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -127,11 +127,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave)
|
return transaction.SaveNameServerGroup(ctx, nsGroupToSave)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -173,11 +173,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
|
return transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -202,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||||
@@ -216,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -226,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups)
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
@@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID)
|
return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||||
@@ -73,7 +73,7 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
|
|||||||
unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
|
unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network)
|
err = m.store.SaveNetwork(ctx, network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to save network: %w", err)
|
return nil, fmt.Errorf("failed to save network: %w", err)
|
||||||
}
|
}
|
||||||
@@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
|
return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||||
@@ -114,7 +114,7 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
|||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
|
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
|
||||||
|
|
||||||
return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network)
|
return network, m.store.SaveNetwork(ctx, network)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||||
@@ -162,12 +162,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
|||||||
eventsToStore = append(eventsToStore, event)
|
eventsToStore = append(eventsToStore, event)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
err = transaction.DeleteNetwork(ctx, accountID, networkID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete network: %w", err)
|
return fmt.Errorf("failed to delete network: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,10 +12,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
@@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
|
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
|
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
|
||||||
@@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
|
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
|
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
|
||||||
@@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
|
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get network resources: %w", err)
|
return nil, fmt.Errorf("failed to get network resources: %w", err)
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
|||||||
|
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
|
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
|
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
|
||||||
}
|
}
|
||||||
@@ -123,7 +123,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
|||||||
return fmt.Errorf("failed to get network: %w", err)
|
return fmt.Errorf("failed to get network: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource)
|
err = transaction.SaveNetworkResource(ctx, resource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to save network resource: %w", err)
|
return fmt.Errorf("failed to save network resource: %w", err)
|
||||||
}
|
}
|
||||||
@@ -145,7 +145,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
|||||||
eventsToStore = append(eventsToStore, event)
|
eventsToStore = append(eventsToStore, event)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID)
|
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
@@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get network resource: %w", err)
|
return nil, fmt.Errorf("failed to get network resource: %w", err)
|
||||||
}
|
}
|
||||||
@@ -218,22 +218,22 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
|||||||
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
|
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
|
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get network resource: %w", err)
|
return fmt.Errorf("failed to get network resource: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
|
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||||
if err == nil && oldResource.ID != resource.ID {
|
if err == nil && oldResource.ID != resource.ID {
|
||||||
return status.Errorf(status.InvalidArgument, "new resource name already exists")
|
return status.Errorf(status.InvalidArgument, "new resource name already exists")
|
||||||
}
|
}
|
||||||
|
|
||||||
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
|
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get network resource: %w", err)
|
return fmt.Errorf("failed to get network resource: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource)
|
err = transaction.SaveNetworkResource(ctx, resource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to save network resource: %w", err)
|
return fmt.Errorf("failed to save network resource: %w", err)
|
||||||
}
|
}
|
||||||
@@ -248,7 +248,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
|||||||
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
|
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
|
||||||
})
|
})
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID)
|
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
@@ -325,7 +325,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
|||||||
return fmt.Errorf("failed to delete resource: %w", err)
|
return fmt.Errorf("failed to delete resource: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
@@ -375,7 +375,7 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
|
|||||||
eventsToStore = append(eventsToStore, event)
|
eventsToStore = append(eventsToStore, event)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID)
|
err = transaction.DeleteNetworkResource(ctx, accountID, resourceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to delete network resource: %w", err)
|
return nil, fmt.Errorf("failed to delete network resource: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
@@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
|
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
|
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
|
||||||
@@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
|
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get network routers: %w", err)
|
return nil, fmt.Errorf("failed to get network routers: %w", err)
|
||||||
}
|
}
|
||||||
@@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
|||||||
|
|
||||||
var network *networkTypes.Network
|
var network *networkTypes.Network
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
|
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get network: %w", err)
|
return fmt.Errorf("failed to get network: %w", err)
|
||||||
}
|
}
|
||||||
@@ -104,12 +104,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
|||||||
|
|
||||||
router.ID = xid.New().String()
|
router.ID = xid.New().String()
|
||||||
|
|
||||||
err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router)
|
err = transaction.SaveNetworkRouter(ctx, router)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create network router: %w", err)
|
return fmt.Errorf("failed to create network router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID)
|
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
@@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID)
|
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get network router: %w", err)
|
return nil, fmt.Errorf("failed to get network router: %w", err)
|
||||||
}
|
}
|
||||||
@@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
|||||||
|
|
||||||
var network *networkTypes.Network
|
var network *networkTypes.Network
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
|
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get network: %w", err)
|
return fmt.Errorf("failed to get network: %w", err)
|
||||||
}
|
}
|
||||||
@@ -171,12 +171,12 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
|||||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router)
|
err = transaction.SaveNetworkRouter(ctx, router)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update network router: %w", err)
|
return fmt.Errorf("failed to update network router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID)
|
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
@@ -213,7 +213,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
|||||||
return fmt.Errorf("failed to delete network router: %w", err)
|
return fmt.Errorf("failed to delete network router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
@@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
|
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
|
||||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
|
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||||
}
|
}
|
||||||
@@ -246,7 +246,7 @@ func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction
|
|||||||
return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
|
return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID)
|
err = transaction.DeleteNetworkRouter(ctx, accountID, routerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to delete network router: %w", err)
|
return nil, fmt.Errorf("failed to delete network router: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,28 +17,28 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||||
// the current user is not an admin.
|
// the current user is not an admin.
|
||||||
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
|||||||
return nil, status.NewPermissionValidationError(err)
|
return nil, status.NewPermissionValidationError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter)
|
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
|||||||
return accountPeers, nil
|
return accountPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get account settings: %w", err)
|
return nil, fmt.Errorf("failed to get account settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -130,7 +130,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
}
|
}
|
||||||
|
|
||||||
if peer.AddedWithSSOLogin() {
|
if peer.AddedWithSSOLogin() {
|
||||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -173,7 +173,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
|
|||||||
peer.Location.CountryCode = location.Country.ISOCode
|
peer.Location.CountryCode = location.Country.ISOCode
|
||||||
peer.Location.CityName = location.City.Names.En
|
peer.Location.CityName = location.City.Names.En
|
||||||
peer.Location.GeoNameID = location.City.GeonameID
|
peer.Location.GeoNameID = location.City.GeonameID
|
||||||
err = transaction.SavePeerLocation(ctx, store.LockingStrengthUpdate, accountID, peer)
|
err = transaction.SavePeerLocation(ctx, accountID, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||||
}
|
}
|
||||||
@@ -182,7 +182,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
|
|||||||
|
|
||||||
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||||
|
|
||||||
err := transaction.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *newStatus)
|
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -219,7 +219,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -281,7 +281,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
inactivityExpirationChanged = true
|
inactivityExpirationChanged = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer)
|
return transaction.SavePeer(ctx, accountID, peer)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -346,7 +346,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
|
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -383,7 +383,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
return fmt.Errorf("failed to delete peer: %w", err)
|
return fmt.Errorf("failed to delete peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -609,7 +609,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
newPeer.DNSLabel = freeLabel
|
newPeer.DNSLabel = freeLabel
|
||||||
newPeer.IP = freeIP
|
newPeer.IP = freeIP
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||||
defer func() {
|
defer func() {
|
||||||
if unlock != nil {
|
if unlock != nil {
|
||||||
unlock()
|
unlock()
|
||||||
@@ -617,7 +617,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
err = transaction.AddPeerToAccount(ctx, newPeer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -658,7 +658,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
@@ -734,7 +734,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
|||||||
var err error
|
var err error
|
||||||
var postureChecks []*posture.Checks
|
var postureChecks []*posture.Checks
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -746,7 +746,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
|||||||
}
|
}
|
||||||
|
|
||||||
if peer.UserID != "" {
|
if peer.UserID != "" {
|
||||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
|
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -774,7 +774,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
|||||||
if updated {
|
if updated {
|
||||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||||
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||||
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil {
|
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -849,7 +849,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
var isPeerUpdated bool
|
var isPeerUpdated bool
|
||||||
var postureChecks []*posture.Checks
|
var postureChecks []*posture.Checks
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -911,7 +911,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
}
|
}
|
||||||
|
|
||||||
if shouldStorePeer {
|
if shouldStorePeer {
|
||||||
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil {
|
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -934,7 +934,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
|
|
||||||
// getPeerPostureChecks returns the posture checks for the peer.
|
// getPeerPostureChecks returns the posture checks for the peer.
|
||||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -958,7 +958,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
|||||||
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
|
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs)
|
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, peerPostureChecksIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -973,7 +973,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources)
|
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -998,7 +998,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
|
|||||||
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
||||||
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
||||||
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error {
|
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error {
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey)
|
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, login.WireGuardPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1009,7 +1009,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1080,7 +1080,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
|
|||||||
// If peer was expired before and if it reached this point, it is re-authenticated.
|
// If peer was expired before and if it reached this point, it is re-authenticated.
|
||||||
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
||||||
peer = peer.UpdateLastLogin()
|
peer = peer.UpdateLastLogin()
|
||||||
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, peer.AccountID, peer)
|
err = transaction.SavePeer(ctx, peer.AccountID, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1090,7 +1090,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
|
|||||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID)
|
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, peer.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get account settings: %w", err)
|
return fmt.Errorf("failed to get account settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1132,7 +1132,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
|
|||||||
|
|
||||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||||
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1145,7 +1145,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
|||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1171,7 +1171,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
|||||||
|
|
||||||
// it is also possible that user doesn't own the peer but some of his peers have access to it,
|
// it is also possible that user doesn't own the peer but some of his peers have access to it,
|
||||||
// this is a valid case, show the peer as well.
|
// this is a valid case, show the peer as well.
|
||||||
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
|
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1275,8 +1275,9 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
|||||||
}
|
}
|
||||||
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
||||||
|
|
||||||
|
peerGroups := account.GetPeerGroups(p.ID)
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting)
|
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups))
|
||||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||||
|
|
||||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||||
@@ -1386,7 +1387,8 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings)
|
peerGroups := account.GetPeerGroups(peerId)
|
||||||
|
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups))
|
||||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1394,7 +1396,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
|||||||
// If there is no peer that expires this function returns false and a duration of 0.
|
// If there is no peer that expires this function returns false and a duration of 0.
|
||||||
// This function only considers peers that haven't been expired yet and that are connected.
|
// This function only considers peers that haven't been expired yet and that are connected.
|
||||||
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
|
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
|
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
|
||||||
return peerSchedulerRetryInterval, true
|
return peerSchedulerRetryInterval, true
|
||||||
@@ -1404,7 +1406,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
|
|||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||||
return peerSchedulerRetryInterval, true
|
return peerSchedulerRetryInterval, true
|
||||||
@@ -1438,7 +1440,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
|
|||||||
// If there is no peer that expires this function returns false and a duration of 0.
|
// If there is no peer that expires this function returns false and a duration of 0.
|
||||||
// This function only considers peers that haven't been expired yet and that are not connected.
|
// This function only considers peers that haven't been expired yet and that are not connected.
|
||||||
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
|
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
|
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
|
||||||
return peerSchedulerRetryInterval, true
|
return peerSchedulerRetryInterval, true
|
||||||
@@ -1448,7 +1450,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
|
|||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||||
return peerSchedulerRetryInterval, true
|
return peerSchedulerRetryInterval, true
|
||||||
@@ -1479,12 +1481,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
|
|||||||
|
|
||||||
// getExpiredPeers returns peers that have been expired.
|
// getExpiredPeers returns peers that have been expired.
|
||||||
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
|
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1502,12 +1504,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
|
|||||||
|
|
||||||
// getInactivePeers returns peers that have been expired by inactivity
|
// getInactivePeers returns peers that have been expired by inactivity
|
||||||
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
|
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1530,7 +1532,7 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p
|
|||||||
|
|
||||||
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
|
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
|
||||||
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
|
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
|
||||||
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID)
|
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||||
@@ -1548,7 +1550,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID
|
|||||||
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
|
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
|
||||||
var peerDeletedEvents []func()
|
var peerDeletedEvents []func()
|
||||||
|
|
||||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1568,7 +1570,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.DeletePeer(ctx, store.LockingStrengthUpdate, accountID, peer.ID); err != nil {
|
if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1624,7 +1626,7 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transac
|
|||||||
|
|
||||||
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
|
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
|
||||||
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
|
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
|
||||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
|
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
|
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
|
||||||
return false, nil
|
return false, nil
|
||||||
|
|||||||
@@ -38,8 +38,6 @@ import (
|
|||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
@@ -47,6 +45,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
nbroute "github.com/netbirdio/netbird/route"
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPeer_LoginExpired(t *testing.T) {
|
func TestPeer_LoginExpired(t *testing.T) {
|
||||||
@@ -1164,7 +1164,7 @@ func TestToSyncResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
dnsCache := &DNSConfigCache{}
|
dnsCache := &DNSConfigCache{}
|
||||||
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
|
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
|
||||||
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil)
|
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{})
|
||||||
|
|
||||||
assert.NotNil(t, response)
|
assert.NotNil(t, response)
|
||||||
// assert peer config
|
// assert peer config
|
||||||
@@ -1307,7 +1307,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
|
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
|
||||||
|
|
||||||
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key)
|
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||||
assert.Equal(t, peer.UserID, existingUserID)
|
assert.Equal(t, peer.UserID, existingUserID)
|
||||||
@@ -1442,7 +1442,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
|||||||
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
|
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
|
||||||
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
|
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
|
||||||
|
|
||||||
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, currentPeer.Key)
|
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key)
|
||||||
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
|
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
|
||||||
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
|
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
|
||||||
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
|
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
|
||||||
@@ -1476,8 +1476,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
|||||||
|
|
||||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||||
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||||
if engine == "sqlite" || engine == "" {
|
if engine == "sqlite" || engine == "mysql" || engine == "" {
|
||||||
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
|
// we intentionally disabled foreign keys in mysql
|
||||||
|
t.Skip("Skipping test because store is not respecting foreign keys")
|
||||||
}
|
}
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
@@ -1528,7 +1529,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
|||||||
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
|
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
|
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
account, err := s.GetAccount(context.Background(), existingAccountID)
|
account, err := s.GetAccount(context.Background(), existingAccountID)
|
||||||
@@ -1699,7 +1700,7 @@ func Test_LoginPeer(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
|
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
|
||||||
|
|
||||||
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, loginInput.WireGuardPubKey)
|
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey)
|
||||||
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
|
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
|
||||||
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
|
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
|
||||||
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")
|
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")
|
||||||
@@ -2160,10 +2161,10 @@ func Test_IsUniqueConstraintError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
err = s.AddPeerToAccount(context.Background(), peer)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
err = s.AddPeerToAccount(context.Background(), peer)
|
||||||
result := isUniqueConstraintError(err)
|
result := isUniqueConstraintError(err)
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
@@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||||
@@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !allowed {
|
if !allowed {
|
||||||
return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
|
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
|
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
|
||||||
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
|
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
@@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions(
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,11 @@ import (
|
|||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
|
return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy in the store
|
// SavePolicy in the store
|
||||||
@@ -61,7 +61,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
saveFunc = transaction.SavePolicy
|
saveFunc = transaction.SavePolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
return saveFunc(ctx, store.LockingStrengthUpdate, policy)
|
return saveFunc(ctx, policy)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -113,11 +113,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
return transaction.DeletePolicy(ctx, accountID, policyID)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
||||||
if isUpdate {
|
if isUpdate {
|
||||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
|||||||
// validatePolicy validates the policy and its rules.
|
// validatePolicy validates the policy and its rules.
|
||||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||||
if policy.ID != "" {
|
if policy.ID != "" {
|
||||||
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
policy.AccountID = accountID
|
policy.AccountID = accountID
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups())
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||||
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
|
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePostureChecks saves a posture check.
|
// SavePostureChecks saves a posture check.
|
||||||
@@ -62,7 +62,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
|||||||
}
|
}
|
||||||
|
|
||||||
postureChecks.AccountID = accountID
|
postureChecks.AccountID = accountID
|
||||||
return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks)
|
return transaction.SavePostureChecks(ctx, postureChecks)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
|||||||
var postureChecks *posture.Checks
|
var postureChecks *posture.Checks
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
|
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -110,11 +110,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID)
|
return transaction.DeletePostureChecks(ctx, accountID, postureChecksID)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||||
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe
|
|||||||
|
|
||||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
|
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
|
||||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
|
|||||||
|
|
||||||
// If the posture check already has an ID, verify its existence in the store.
|
// If the posture check already has an ID, verify its existence in the store.
|
||||||
if postureChecks.ID != "" {
|
if postureChecks.ID != "" {
|
||||||
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// For new posture checks, ensure no duplicates by name.
|
// For new posture checks, ensure no duplicates by name.
|
||||||
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
|
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
|||||||
|
|
||||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||||
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
|
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
|
||||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,15 +9,15 @@ import (
|
|||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetRoute gets a route object from account and route IDs
|
// GetRoute gets a route object from account and route IDs
|
||||||
@@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
|
return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||||
@@ -59,7 +59,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
|
|||||||
seenPeers[string(prefixRoute.ID)] = true
|
seenPeers[string(prefixRoute.ID)] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
|
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, prefixRoute.PeerGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -83,7 +83,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
|
|||||||
|
|
||||||
if peerID := checkRoute.Peer; peerID != "" {
|
if peerID := checkRoute.Peer; peerID != "" {
|
||||||
// check that peerID exists and is not in any route as single peer or part of the group
|
// check that peerID exists and is not in any route as single peer or part of the group
|
||||||
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
|
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthNone, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||||
}
|
}
|
||||||
@@ -104,7 +104,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||||
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
|
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, group.Peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -181,11 +181,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute)
|
return transaction.SaveRoute(ctx, newRoute)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -238,11 +238,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
}
|
}
|
||||||
routeToSave.AccountID = accountID
|
routeToSave.AccountID = accountID
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave)
|
return transaction.SaveRoute(ctx, routeToSave)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -284,11 +284,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
return transaction.DeleteRoute(ctx, accountID, string(routeID))
|
||||||
})
|
})
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||||
@@ -310,7 +310,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
|
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
|
||||||
@@ -353,7 +353,7 @@ func validateRoute(ctx context.Context, transaction store.Store, accountID strin
|
|||||||
// validateRouteGroups validates the route groups and returns the validated groups map.
|
// validateRouteGroups validates the route groups and returns the validated groups map.
|
||||||
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
|
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
|
||||||
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
|
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
|
||||||
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
|
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupsToValidate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -494,7 +494,7 @@ func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, ro
|
|||||||
|
|
||||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||||
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||||
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
@@ -27,6 +26,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -1100,7 +1100,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
||||||
|
|
||||||
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id)
|
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var groupHA1, groupHA2 *types.Group
|
var groupHA1, groupHA2 *types.Group
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
@@ -60,7 +60,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
|
|||||||
return nil, fmt.Errorf("get extra settings: %w", err)
|
return nil, fmt.Errorf("get extra settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get account settings: %w", err)
|
return nil, fmt.Errorf("get account settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -68,6 +68,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
|
|||||||
// Once we migrate the peer approval to settings manager this merging is obsolete
|
// Once we migrate the peer approval to settings manager this merging is obsolete
|
||||||
if settings.Extra != nil {
|
if settings.Extra != nil {
|
||||||
settings.Extra.FlowEnabled = extraSettings.FlowEnabled
|
settings.Extra.FlowEnabled = extraSettings.FlowEnabled
|
||||||
|
settings.Extra.FlowGroups = extraSettings.FlowGroups
|
||||||
settings.Extra.FlowPacketCounterEnabled = extraSettings.FlowPacketCounterEnabled
|
settings.Extra.FlowPacketCounterEnabled = extraSettings.FlowPacketCounterEnabled
|
||||||
settings.Extra.FlowENCollectionEnabled = extraSettings.FlowENCollectionEnabled
|
settings.Extra.FlowENCollectionEnabled = extraSettings.FlowENCollectionEnabled
|
||||||
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
|
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
|
||||||
@@ -82,7 +83,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
|
|||||||
return nil, fmt.Errorf("get extra settings: %w", err)
|
return nil, fmt.Errorf("get extra settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get account settings: %w", err)
|
return nil, fmt.Errorf("get account settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -93,6 +94,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
settings.Extra.FlowEnabled = extraSettings.FlowEnabled
|
settings.Extra.FlowEnabled = extraSettings.FlowEnabled
|
||||||
|
settings.Extra.FlowGroups = extraSettings.FlowGroups
|
||||||
|
|
||||||
return settings.Extra, nil
|
return settings.Extra, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -81,7 +81,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
|||||||
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
|
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
|
||||||
eventsToStore = append(eventsToStore, events...)
|
eventsToStore = append(eventsToStore, events...)
|
||||||
|
|
||||||
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey)
|
return transaction.SaveSetupKey(ctx, setupKey)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
|||||||
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
|
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)
|
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthUpdate, accountID, keyToSave.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -148,7 +148,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
|||||||
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
|
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
|
||||||
eventsToStore = append(eventsToStore, events...)
|
eventsToStore = append(eventsToStore, events...)
|
||||||
|
|
||||||
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey)
|
return transaction.SaveSetupKey(ctx, newKey)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||||
@@ -188,7 +188,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
|
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -214,12 +214,12 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
|||||||
var deletedSetupKey *types.SetupKey
|
var deletedSetupKey *types.SetupKey
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
|
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthUpdate, accountID, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID)
|
return transaction.DeleteSetupKey(ctx, accountID, keyID)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -231,7 +231,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
|
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs)
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, autoGroupIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
|
|||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
|
|
||||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
|
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -72,8 +72,8 @@ type Store interface {
|
|||||||
SaveAccount(ctx context.Context, account *types.Account) error
|
SaveAccount(ctx context.Context, account *types.Account) error
|
||||||
DeleteAccount(ctx context.Context, account *types.Account) error
|
DeleteAccount(ctx context.Context, account *types.Account) error
|
||||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
SaveDNSSettings(ctx context.Context, accountID string, settings *types.DNSSettings) error
|
||||||
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error
|
SaveAccountSettings(ctx context.Context, accountID string, settings *types.Settings) error
|
||||||
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
|
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
|
||||||
SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error
|
SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error
|
||||||
|
|
||||||
@@ -81,10 +81,10 @@ type Store interface {
|
|||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
||||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
||||||
GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error)
|
GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error)
|
||||||
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error
|
SaveUsers(ctx context.Context, users []*types.User) error
|
||||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
|
SaveUser(ctx context.Context, user *types.User) error
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
|
DeleteUser(ctx context.Context, accountID, userID string) error
|
||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
@@ -92,34 +92,34 @@ type Store interface {
|
|||||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
||||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
||||||
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
|
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
|
||||||
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
MarkPATUsed(ctx context.Context, patID string) error
|
||||||
SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error
|
SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error
|
||||||
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
DeletePAT(ctx context.Context, userID, patID string) error
|
||||||
|
|
||||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
||||||
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
||||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
||||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
|
||||||
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
|
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
|
||||||
CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
|
CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
|
||||||
UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
|
UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error
|
||||||
CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
|
CreateGroup(ctx context.Context, group *types.Group) error
|
||||||
UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
|
UpdateGroup(ctx context.Context, group *types.Group) error
|
||||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
DeleteGroup(ctx context.Context, accountID, groupID string) error
|
||||||
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
DeleteGroups(ctx context.Context, accountID string, groupIDs []string) error
|
||||||
|
|
||||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error)
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error)
|
||||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error)
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error)
|
||||||
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error
|
CreatePolicy(ctx context.Context, policy *types.Policy) error
|
||||||
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error
|
SavePolicy(ctx context.Context, policy *types.Policy) error
|
||||||
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
|
DeletePolicy(ctx context.Context, accountID, policyID string) error
|
||||||
|
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||||
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
|
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
|
||||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
SavePostureChecks(ctx context.Context, postureCheck *posture.Checks) error
|
||||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
DeletePostureChecks(ctx context.Context, accountID, postureChecksID string) error
|
||||||
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
|
||||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
@@ -130,7 +130,7 @@ type Store interface {
|
|||||||
GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error)
|
GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error)
|
||||||
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
|
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
|
||||||
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
|
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
|
||||||
AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error
|
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||||
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@@ -139,30 +139,30 @@ type Store interface {
|
|||||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||||
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
|
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
||||||
|
|
||||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
|
||||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error)
|
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error)
|
||||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error)
|
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error)
|
||||||
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error
|
SaveSetupKey(ctx context.Context, setupKey *types.SetupKey) error
|
||||||
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
|
||||||
|
|
||||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
||||||
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
|
SaveRoute(ctx context.Context, route *route.Route) error
|
||||||
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
|
DeleteRoute(ctx context.Context, accountID, routeID string) error
|
||||||
|
|
||||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||||
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
|
SaveNameServerGroup(ctx context.Context, nameServerGroup *dns.NameServerGroup) error
|
||||||
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
|
DeleteNameServerGroup(ctx context.Context, accountID, nameServerGroupID string) error
|
||||||
|
|
||||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)
|
||||||
|
|
||||||
GetInstallationID() string
|
GetInstallationID() string
|
||||||
@@ -184,21 +184,21 @@ type Store interface {
|
|||||||
|
|
||||||
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
||||||
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
|
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
|
||||||
SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error
|
SaveNetwork(ctx context.Context, network *networkTypes.Network) error
|
||||||
DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error
|
DeleteNetwork(ctx context.Context, accountID, networkID string) error
|
||||||
|
|
||||||
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
|
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
|
||||||
GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error)
|
GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error)
|
||||||
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
|
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
|
||||||
SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error
|
SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
|
||||||
DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error
|
DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error
|
||||||
|
|
||||||
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
|
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
|
||||||
GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error)
|
GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error)
|
||||||
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error)
|
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error)
|
||||||
GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error)
|
GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error)
|
||||||
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
|
SaveNetworkResource(ctx context.Context, resource *resourceTypes.NetworkResource) error
|
||||||
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
|
DeleteNetworkResource(ctx context.Context, accountID, resourceID string) error
|
||||||
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
||||||
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
||||||
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
|
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
|
||||||
|
|||||||
@@ -11,13 +11,13 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||||
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
|
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
|
||||||
|
|
||||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultDuration = 12 * time.Hour
|
const defaultDuration = 12 * time.Hour
|
||||||
@@ -39,13 +39,14 @@ type TimeBasedAuthSecretsManager struct {
|
|||||||
relayHmacToken *authv2.Generator
|
relayHmacToken *authv2.Generator
|
||||||
updateManager *PeersUpdateManager
|
updateManager *PeersUpdateManager
|
||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
|
groupsManager groups.Manager
|
||||||
turnCancelMap map[string]chan struct{}
|
turnCancelMap map[string]chan struct{}
|
||||||
relayCancelMap map[string]chan struct{}
|
relayCancelMap map[string]chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Token auth.Token
|
type Token auth.Token
|
||||||
|
|
||||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
|
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
|
||||||
mgr := &TimeBasedAuthSecretsManager{
|
mgr := &TimeBasedAuthSecretsManager{
|
||||||
updateManager: updateManager,
|
updateManager: updateManager,
|
||||||
turnCfg: turnCfg,
|
turnCfg: turnCfg,
|
||||||
@@ -53,6 +54,7 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
|
|||||||
turnCancelMap: make(map[string]chan struct{}),
|
turnCancelMap: make(map[string]chan struct{}),
|
||||||
relayCancelMap: make(map[string]chan struct{}),
|
relayCancelMap: make(map[string]chan struct{}),
|
||||||
settingsManager: settingsManager,
|
settingsManager: settingsManager,
|
||||||
|
groupsManager: groupsManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
if turnCfg != nil {
|
if turnCfg != nil {
|
||||||
@@ -258,6 +260,11 @@ func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, p
|
|||||||
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
|
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings)
|
peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
|
||||||
update.NetbirdConfig = extendedConfig
|
update.NetbirdConfig = extendedConfig
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,9 +13,10 @@ import (
|
|||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,13 +41,14 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
t.Cleanup(ctrl.Finish)
|
t.Cleanup(ctrl.Finish)
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
||||||
CredentialsTTL: ttl,
|
CredentialsTTL: ttl,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Turns: []*types.Host{TurnTestHost},
|
Turns: []*types.Host{TurnTestHost},
|
||||||
TimeBasedCredentials: true,
|
TimeBasedCredentials: true,
|
||||||
}, rc, settingsMockManager)
|
}, rc, settingsMockManager, groupsManager)
|
||||||
|
|
||||||
turnCredentials, err := tested.GenerateTurnToken()
|
turnCredentials, err := tested.GenerateTurnToken()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -91,13 +93,14 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
|||||||
t.Cleanup(ctrl.Finish)
|
t.Cleanup(ctrl.Finish)
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
||||||
CredentialsTTL: ttl,
|
CredentialsTTL: ttl,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Turns: []*types.Host{TurnTestHost},
|
Turns: []*types.Host{TurnTestHost},
|
||||||
TimeBasedCredentials: true,
|
TimeBasedCredentials: true,
|
||||||
}, rc, settingsMockManager)
|
}, rc, settingsMockManager, groupsManager)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -193,13 +196,14 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
t.Cleanup(ctrl.Finish)
|
t.Cleanup(ctrl.Finish)
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
||||||
CredentialsTTL: ttl,
|
CredentialsTTL: ttl,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Turns: []*types.Host{TurnTestHost},
|
Turns: []*types.Host{TurnTestHost},
|
||||||
TimeBasedCredentials: true,
|
TimeBasedCredentials: true,
|
||||||
}, rc, settingsMockManager)
|
}, rc, settingsMockManager, groupsManager)
|
||||||
|
|
||||||
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
||||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
if _, ok := tested.turnCancelMap[peer]; !ok {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package types
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,21 +88,21 @@ type ExtraSettings struct {
|
|||||||
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
|
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
|
||||||
IntegratedValidatorGroups []string `gorm:"serializer:json"`
|
IntegratedValidatorGroups []string `gorm:"serializer:json"`
|
||||||
|
|
||||||
FlowEnabled bool `gorm:"-"`
|
FlowEnabled bool `gorm:"-"`
|
||||||
FlowPacketCounterEnabled bool `gorm:"-"`
|
FlowGroups []string `gorm:"-"`
|
||||||
FlowENCollectionEnabled bool `gorm:"-"`
|
FlowPacketCounterEnabled bool `gorm:"-"`
|
||||||
FlowDnsCollectionEnabled bool `gorm:"-"`
|
FlowENCollectionEnabled bool `gorm:"-"`
|
||||||
|
FlowDnsCollectionEnabled bool `gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy copies the ExtraSettings struct
|
// Copy copies the ExtraSettings struct
|
||||||
func (e *ExtraSettings) Copy() *ExtraSettings {
|
func (e *ExtraSettings) Copy() *ExtraSettings {
|
||||||
var cpGroup []string
|
|
||||||
|
|
||||||
return &ExtraSettings{
|
return &ExtraSettings{
|
||||||
PeerApprovalEnabled: e.PeerApprovalEnabled,
|
PeerApprovalEnabled: e.PeerApprovalEnabled,
|
||||||
IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...),
|
IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups),
|
||||||
IntegratedValidator: e.IntegratedValidator,
|
IntegratedValidator: e.IntegratedValidator,
|
||||||
FlowEnabled: e.FlowEnabled,
|
FlowEnabled: e.FlowEnabled,
|
||||||
|
FlowGroups: slices.Clone(e.FlowGroups),
|
||||||
FlowPacketCounterEnabled: e.FlowPacketCounterEnabled,
|
FlowPacketCounterEnabled: e.FlowPacketCounterEnabled,
|
||||||
FlowENCollectionEnabled: e.FlowENCollectionEnabled,
|
FlowENCollectionEnabled: e.FlowENCollectionEnabled,
|
||||||
FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled,
|
FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled,
|
||||||
|
|||||||
@@ -17,11 +17,11 @@ import (
|
|||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// createServiceUser creates a new service user under the given account.
|
// createServiceUser creates a new service user under the given account.
|
||||||
@@ -46,7 +46,7 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
|
|||||||
newUser.AccountID = accountID
|
newUser.AccountID = accountID
|
||||||
log.WithContext(ctx).Debugf("New User: %v", newUser)
|
log.WithContext(ctx).Debugf("New User: %v", newUser)
|
||||||
|
|
||||||
if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil {
|
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,14 +95,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
inviterID := userID
|
inviterID := userID
|
||||||
if initiatorUser.IsServiceUser {
|
if initiatorUser.IsServiceUser {
|
||||||
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID)
|
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -124,7 +124,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil {
|
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,13 +178,13 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||||
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
|
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUser looks up a user by provided nbContext.UserAuths.
|
// GetUser looks up a user by provided nbContext.UserAuths.
|
||||||
// Expects account to have been created already.
|
// Expects account to have been created already.
|
||||||
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
|
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -209,11 +209,11 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
|
|||||||
// ListUsers returns lists of all users under the account.
|
// ListUsers returns lists of all users under the account.
|
||||||
// It doesn't populate user information such as email or name.
|
// It doesn't populate user information such as email or name.
|
||||||
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
|
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
|
||||||
return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error {
|
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error {
|
||||||
if err := am.Store.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUser.Id); err != nil {
|
if err := am.Store.DeleteUser(ctx, accountID, targetUser.Id); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt}
|
meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt}
|
||||||
@@ -230,7 +230,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
|||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -243,7 +243,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -347,12 +347,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -367,7 +367,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
|||||||
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
|
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SavePAT(ctx, store.LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil {
|
if err = am.Store.SavePAT(ctx, &pat.PersonalAccessToken); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -390,12 +390,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -404,12 +404,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
|||||||
return status.NewAdminPermissionError()
|
return status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
|
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.DeletePAT(ctx, store.LockingStrengthUpdate, targetUserID, tokenID); err != nil {
|
if err = am.Store.DeletePAT(ctx, targetUserID, tokenID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -429,12 +429,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
|||||||
return nil, status.NewAdminPermissionError()
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
|
return am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPATs returns all PATs for a user
|
// GetAllPATs returns all PATs for a user
|
||||||
@@ -456,12 +456,12 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -470,7 +470,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
|
|||||||
return nil, status.NewAdminPermissionError()
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID)
|
return am.Store.GetUserPATs(ctx, store.LockingStrengthNone, targetUserID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
|
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
|
||||||
@@ -511,7 +511,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
if !allowed {
|
if !allowed {
|
||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -521,7 +521,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
var addUserEvents []func()
|
var addUserEvents []func()
|
||||||
var usersToSave = make([]*types.User, 0, len(updates))
|
var usersToSave = make([]*types.User, 0, len(updates))
|
||||||
|
|
||||||
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting account groups: %w", err)
|
return nil, fmt.Errorf("error getting account groups: %w", err)
|
||||||
}
|
}
|
||||||
@@ -533,7 +533,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
|
|
||||||
var initiatorUser *types.User
|
var initiatorUser *types.User
|
||||||
if initiatorUserID != activity.SystemInitiator {
|
if initiatorUserID != activity.SystemInitiator {
|
||||||
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -560,7 +560,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
updateAccountPeers = true
|
updateAccountPeers = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return transaction.SaveUsers(ctx, store.LockingStrengthUpdate, usersToSave)
|
return transaction.SaveUsers(ctx, usersToSave)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -593,7 +593,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if settings.GroupsPropagationEnabled && updateAccountPeers {
|
if settings.GroupsPropagationEnabled && updateAccountPeers {
|
||||||
if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
am.UpdateAccountPeers(ctx, accountID)
|
am.UpdateAccountPeers(ctx, accountID)
|
||||||
@@ -700,7 +700,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
|||||||
|
|
||||||
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
|
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
|
||||||
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
|
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
|
||||||
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
|
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
if !addIfNotExists {
|
if !addIfNotExists {
|
||||||
@@ -724,7 +724,7 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi
|
|||||||
newInitiatorUser := initiatorUser.Copy()
|
newInitiatorUser := initiatorUser.Copy()
|
||||||
newInitiatorUser.Role = types.UserRoleAdmin
|
newInitiatorUser.Role = types.UserRoleAdmin
|
||||||
|
|
||||||
if err := transaction.SaveUser(ctx, store.LockingStrengthUpdate, newInitiatorUser); err != nil {
|
if err := transaction.SaveUser(ctx, newInitiatorUser); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
@@ -835,7 +835,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
|||||||
|
|
||||||
var user *types.User
|
var user *types.User
|
||||||
if initiatorUserID != activity.SystemInitiator {
|
if initiatorUserID != activity.SystemInitiator {
|
||||||
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||||
}
|
}
|
||||||
@@ -845,7 +845,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
|||||||
accountUsers := []*types.User{}
|
accountUsers := []*types.User{}
|
||||||
switch {
|
switch {
|
||||||
case allowed:
|
case allowed:
|
||||||
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -939,7 +939,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
|||||||
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
|
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
|
||||||
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
|
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
|
||||||
log.WithContext(ctx).Debugf("Expiring %d peers for account %s", len(peers), accountID)
|
log.WithContext(ctx).Debugf("Expiring %d peers for account %s", len(peers), accountID)
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -956,7 +956,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
|||||||
peerIDs = append(peerIDs, peer.ID)
|
peerIDs = append(peerIDs, peer.ID)
|
||||||
peer.MarkLoginExpired(true)
|
peer.MarkLoginExpired(true)
|
||||||
|
|
||||||
if err := am.Store.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil {
|
if err := am.Store.SavePeerStatus(ctx, accountID, peer.ID, *peer.Status); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
am.StoreEvent(
|
am.StoreEvent(
|
||||||
@@ -1009,7 +1009,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1023,7 +1023,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
allErrors = errors.Join(allErrors, err)
|
allErrors = errors.Join(allErrors, err)
|
||||||
continue
|
continue
|
||||||
@@ -1087,12 +1087,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID)
|
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, targetUserInfo.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get user to delete: %w", err)
|
return fmt.Errorf("failed to get user to delete: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID)
|
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get user peers: %w", err)
|
return fmt.Errorf("failed to get user peers: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1105,7 +1105,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUserInfo.ID); err != nil {
|
if err = transaction.DeleteUser(ctx, accountID, targetUserInfo.ID); err != nil {
|
||||||
return fmt.Errorf("failed to delete user: %s %w", targetUserInfo.ID, err)
|
return fmt.Errorf("failed to delete user: %s %w", targetUserInfo.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1126,7 +1126,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
|||||||
|
|
||||||
// GetOwnerInfo retrieves the owner information for a given account ID.
|
// GetOwnerInfo retrieves the owner information for a given account ID.
|
||||||
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
|
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
|
||||||
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID)
|
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1176,7 +1176,7 @@ func validateUserInvite(invite *types.UserInfo) error {
|
|||||||
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||||
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1193,7 +1193,7 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -88,7 +88,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, pat.ID, tokenID)
|
assert.Equal(t, pat.ID, tokenID)
|
||||||
|
|
||||||
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
|
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthNone, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||||
}
|
}
|
||||||
@@ -1521,7 +1521,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
|||||||
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
|
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
|
||||||
assert.Error(t, err, "update user to another account should fail")
|
assert.Error(t, err, "update user to another account should fail")
|
||||||
|
|
||||||
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId)
|
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, account1.Users[targetId].Id, user.Id)
|
assert.Equal(t, account1.Users[targetId].Id, user.Id)
|
||||||
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)
|
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func NewManager(store store.Store) Manager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
||||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManagerMock() Manager {
|
func NewManagerMock() Manager {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user