Compare commits

...

17 Commits

Author SHA1 Message Date
bcmmbaga
2b86463e96 fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-08-12 13:33:46 +03:00
bcmmbaga
9deff6f06b Merge branch 'main' into handle-existing-domain-user
# Conflicts:
#	management/server/account.go
#	management/server/account_test.go
2025-08-12 13:31:40 +03:00
Pascal Fischer
a942e4add5 [management] use readlock on add peer (#4308) 2025-08-11 15:21:26 +02:00
Viktor Liu
1022a5015c [client] Eliminate upstream server strings in dns code (#4267) 2025-08-11 11:57:21 +02:00
Maycon Santos
375fcf2752 [misc] Post release to forum (#4312) 2025-08-08 21:41:33 +02:00
Maycon Santos
9acf7f9262 [client] Update Windows installer description (#4306)
* [client] Update Windows installer description

* Update netbird.wxs
2025-08-08 21:18:58 +02:00
Viktor Liu
82937ba184 [client] Increase logout timeout (#4311) 2025-08-08 19:16:48 +02:00
Maycon Santos
0f52144894 [misc] Add docs acknowledgement check (#4310)
adds a GitHub Actions workflow to enforce documentation requirements for pull requests, ensuring contributors acknowledge whether their changes need documentation updates or provide a link to a corresponding docs PR.

- Adds a new GitHub Actions workflow that validates documentation acknowledgement in PR descriptions
- Updates the PR template to include mandatory documentation checkboxes and URL field
- Implements validation logic to ensure exactly one documentation option is selected and verifies docs PR URLs when provided
2025-08-08 18:14:26 +02:00
Krzysztof Nazarewski (kdn)
0926400b8a fix: profilemanager panic when reading incomplete config (#4309)
fix: profilemanager panic when reading incomplete config (#4309)
2025-08-08 18:44:25 +03:00
Viktor Liu
bef99d48f8 [client] Rename logout to deregister (#4307) 2025-08-08 15:48:30 +02:00
Pascal Fischer
9e95841252 [management] during JSON migration filter duplicates on conflict (#4303) 2025-08-07 14:12:07 +02:00
hakansa
6da3943559 [client] fix ssh command for non-default profile (#4298)
[client] fix ssh command for non-default profile (#4298)
2025-08-07 13:08:30 +03:00
Pascal Fischer
f5b4659adb [management] Mark SaveAccount deprecated (#4300) 2025-08-07 11:49:37 +02:00
Viktor Liu
3d19468b6c [client] Add windows arm64 build (#4206) 2025-08-07 11:30:19 +02:00
bcmmbaga
1a1e94c805 Check account existence without fully loading it
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-08 18:26:46 +03:00
bcmmbaga
ed939bf7f5 add unit test
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-08 18:09:14 +03:00
bcmmbaga
7caf733217 Skip adding user to domain account if already exists
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-08 15:50:32 +03:00
51 changed files with 549 additions and 415 deletions

View File

@@ -12,6 +12,16 @@
- [ ] Is a feature enhancement - [ ] Is a feature enhancement
- [ ] It is a refactor - [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible) - [ ] Created tests that fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md). > By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
## Documentation
Select exactly one:
- [ ] I added/updated documentation for this change
- [ ] Documentation is **not needed** for this change (explain why)
### Docs PR URL (required if "docs added" is checked)
Paste the PR link from https://github.com/netbirdio/docs here:
https://github.com/netbirdio/docs/pull/__

94
.github/workflows/docs-ack.yml vendored Normal file
View File

@@ -0,0 +1,94 @@
name: Docs Acknowledgement
on:
pull_request:
types: [opened, edited, synchronize]
permissions:
contents: read
pull-requests: read
jobs:
docs-ack:
name: Require docs PR URL or explicit "not needed"
runs-on: ubuntu-latest
steps:
- name: Read PR body
id: body
run: |
BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH")
echo "body<<EOF" >> $GITHUB_OUTPUT
echo "$BODY" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
- name: Validate checkbox selection
id: validate
run: |
body='${{ steps.body.outputs.body }}'
added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ')
noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ')
if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then
echo "::error::Choose exactly one: either 'docs added' OR 'not needed'."
exit 1
fi
if [ "$added_checked" -eq 0 ] && [ "$noneed_checked" -eq 0 ]; then
echo "::error::You must check exactly one docs option in the PR template."
exit 1
fi
if [ "$added_checked" -eq 1 ]; then
echo "mode=added" >> $GITHUB_OUTPUT
else
echo "mode=noneed" >> $GITHUB_OUTPUT
fi
- name: Extract docs PR URL (when 'docs added')
if: steps.validate.outputs.mode == 'added'
id: extract
run: |
body='${{ steps.body.outputs.body }}'
# Strictly require HTTPS and that it's a PR in netbirdio/docs
# Examples accepted:
# https://github.com/netbirdio/docs/pull/1234
url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)
if [ -z "$url" ]; then
echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)."
exit 1
fi
pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')
echo "url=$url" >> $GITHUB_OUTPUT
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT
- name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added'
uses: actions/github-script@v7
id: verify
with:
pr_number: ${{ steps.extract.outputs.pr_number }}
script: |
const prNumber = parseInt(core.getInput('pr_number'), 10);
const { data } = await github.rest.pulls.get({
owner: 'netbirdio',
repo: 'docs',
pull_number: prNumber
});
// Allow open or merged PRs
const ok = data.state === 'open' || data.merged === true;
core.setOutput('state', data.state);
core.setOutput('merged', String(!!data.merged));
if (!ok) {
core.setFailed(`Docs PR #${prNumber} exists but is neither open nor merged (state=${data.state}, merged=${data.merged}).`);
}
result-encoding: string
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: All good
run: echo "Documentation requirement satisfied ✅"

18
.github/workflows/forum.yml vendored Normal file
View File

@@ -0,0 +1,18 @@
name: Post release topic on Discourse
on:
release:
types: [published]
jobs:
post:
runs-on: ubuntu-latest
steps:
- uses: roots/discourse-topic-github-release-action@main
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.21" SIGN_PIPE_VER: "v0.0.22"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -79,6 +79,8 @@ jobs:
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
@@ -154,10 +156,20 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install LLVM-MinGW for ARM64 cross-compilation
run: |
cd /tmp
wget -q https://github.com/mstorsjo/llvm-mingw/releases/download/20250709/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
echo "60cafae6474c7411174cff1d4ba21a8e46cadbaeb05a1bace306add301628337 llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz" | sha256sum -c
tar -xf llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
@@ -231,17 +243,3 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -16,8 +16,6 @@ builds:
- arm64 - arm64
- 386 - 386
ignore: ignore:
- goos: windows
goarch: arm64
- goos: windows - goos: windows
goarch: arm goarch: arm
- goos: windows - goos: windows

View File

@@ -15,7 +15,7 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows - id: netbird-ui-windows-amd64
dir: client/ui dir: client/ui
binary: netbird-ui binary: netbird-ui
env: env:
@@ -30,6 +30,22 @@ builds:
- -H windowsgui - -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows-arm64
dir: client/ui
binary: netbird-ui
env:
- CGO_ENABLED=1
- CC=aarch64-w64-mingw32-clang
- CXX=aarch64-w64-mingw32-clang++
goos:
- windows
goarch:
- arm64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}"
archives: archives:
- id: linux-arch - id: linux-arch
name_template: "{{ .ProjectName }}-linux_{{ .Version }}_{{ .Os }}_{{ .Arch }}" name_template: "{{ .ProjectName }}-linux_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
@@ -38,7 +54,8 @@ archives:
- id: windows-arch - id: windows-arch
name_template: "{{ .ProjectName }}-windows_{{ .Version }}_{{ .Os }}_{{ .Arch }}" name_template: "{{ .ProjectName }}-windows_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds: builds:
- netbird-ui-windows - netbird-ui-windows-amd64
- netbird-ui-windows-arm64
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>

View File

@@ -4,6 +4,7 @@ package android
import ( import (
"context" "context"
"slices"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -112,7 +113,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -138,7 +139,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -235,7 +236,7 @@ func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
return err return err
} }
dnsServer.OnUpdatedHostDNSServer(list.items) dnsServer.OnUpdatedHostDNSServer(slices.Clone(list.items))
return nil return nil
} }

View File

@@ -1,23 +1,34 @@
package android package android
import "fmt" import (
"fmt"
"net/netip"
// DNSList is a wrapper of []string "github.com/netbirdio/netbird/client/internal/dns"
)
// DNSList is a wrapper of []netip.AddrPort with default DNS port
type DNSList struct { type DNSList struct {
items []string items []netip.AddrPort
} }
// Add new DNS address to the collection // Add new DNS address to the collection, returns error if invalid
func (array *DNSList) Add(s string) { func (array *DNSList) Add(s string) error {
array.items = append(array.items, s) addr, err := netip.ParseAddr(s)
if err != nil {
return fmt.Errorf("invalid DNS address: %s", s)
}
addrPort := netip.AddrPortFrom(addr.Unmap(), dns.DefaultPort)
array.items = append(array.items, addrPort)
return nil
} }
// Get return an element of the collection // Get return an element of the collection as string
func (array *DNSList) Get(i int) (string, error) { func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 { if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range") return "", fmt.Errorf("out of range")
} }
return array.items[i], nil return array.items[i].Addr().String(), nil
} }
// Size return with the size of the collection // Size return with the size of the collection

View File

@@ -3,20 +3,30 @@ package android
import "testing" import "testing"
func TestDNSList_Get(t *testing.T) { func TestDNSList_Get(t *testing.T) {
l := DNSList{ l := DNSList{}
items: make([]string, 1),
// Add a valid DNS address
err := l.Add("8.8.8.8")
if err != nil {
t.Errorf("unexpected error: %s", err)
} }
_, err := l.Get(0) // Test getting valid index
addr, err := l.Get(0)
if err != nil { if err != nil {
t.Errorf("invalid error: %s", err) t.Errorf("invalid error: %s", err)
} }
if addr != "8.8.8.8" {
t.Errorf("expected 8.8.8.8, got %s", addr)
}
// Test negative index
_, err = l.Get(-1) _, err = l.Get(-1)
if err == nil { if err == nil {
t.Errorf("expected error but got nil") t.Errorf("expected error but got nil")
} }
// Test out of bounds index
_, err = l.Get(1) _, err = l.Get(1)
if err == nil { if err == nil {
t.Errorf("expected error but got nil") t.Errorf("expected error but got nil")

View File

@@ -12,14 +12,15 @@ import (
) )
var logoutCmd = &cobra.Command{ var logoutCmd = &cobra.Command{
Use: "logout", Use: "deregister",
Short: "logout from the NetBird Management Service and delete peer", Aliases: []string{"logout"},
Short: "deregister from the NetBird Management Service and delete peer",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) ctx, cancel := context.WithTimeout(cmd.Context(), time.Second*15)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
@@ -44,10 +45,10 @@ var logoutCmd = &cobra.Command{
} }
if _, err := daemonClient.Logout(ctx, req); err != nil { if _, err := daemonClient.Logout(ctx, req); err != nil {
return fmt.Errorf("logout: %v", err) return fmt.Errorf("deregister: %v", err)
} }
cmd.Println("Logged out successfully") cmd.Println("Deregistered successfully")
return nil return nil
}, },
} }

View File

@@ -47,7 +47,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` + serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` + `You can specify a comma-separated list of KEY=VALUE pairs. ` +
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value` `E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)

View File

@@ -59,8 +59,8 @@ var sshCmd = &cobra.Command{
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
pm := profilemanager.NewProfileManager() sm := profilemanager.NewServiceManager(configPath)
activeProf, err := pm.GetActiveProfile() activeProf, err := sm.GetActiveProfileState()
if err != nil { if err != nil {
return fmt.Errorf("get active profile: %v", err) return fmt.Errorf("get active profile: %v", err)
} }

View File

@@ -3,7 +3,7 @@
!define WEB_SITE "Netbird.io" !define WEB_SITE "Netbird.io"
!define VERSION $%APPVER% !define VERSION $%APPVER%
!define COPYRIGHT "Netbird Authors, 2022" !define COPYRIGHT "Netbird Authors, 2022"
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network" !define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
!define INSTALLER_NAME "netbird-installer.exe" !define INSTALLER_NAME "netbird-installer.exe"
!define MAIN_APP_EXE "Netbird" !define MAIN_APP_EXE "Netbird"
!define ICON "ui\\assets\\netbird.ico" !define ICON "ui\\assets\\netbird.ico"
@@ -59,9 +59,15 @@ ShowInstDetails Show
!define MUI_UNICON "${ICON}" !define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
!define MUI_UNWELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_UNWELCOMEFINISHPAGE_BITMAP "${BANNER}"
!define MUI_FINISHPAGE_RUN !ifndef ARCH
!define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}" !define ARCH "amd64"
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink" !endif
!if ${ARCH} == "amd64"
!define MUI_FINISHPAGE_RUN
!define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}"
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
!endif
###################################################################### ######################################################################
!define MUI_ABORTWARNING !define MUI_ABORTWARNING
@@ -213,7 +219,15 @@ Section -MainProgram
${INSTALL_TYPE} ${INSTALL_TYPE}
# SetOverwrite ifnewer # SetOverwrite ifnewer
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\" !ifndef ARCH
!define ARCH "amd64"
!endif
!if ${ARCH} == "arm64"
File /r "..\\dist\\netbird_windows_arm64\\"
!else
File /r "..\\dist\\netbird_windows_amd64\\"
!endif
SectionEnd SectionEnd
###################################################################### ######################################################################
@@ -292,7 +306,9 @@ DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
!if ${ARCH} == "amd64"
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
!endif
DetailPrint "Removing application directory..." DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
@@ -314,8 +330,10 @@ DetailPrint "Uninstallation finished."
SectionEnd SectionEnd
!if ${ARCH} == "amd64"
Function LaunchLink Function LaunchLink
SetShellVarContext all SetShellVarContext all
SetOutPath $INSTDIR SetOutPath $INSTDIR
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk" ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
FunctionEnd FunctionEnd
!endif

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
@@ -70,7 +71,7 @@ func (c *ConnectClient) RunOnAndroid(
tunAdapter device.TunAdapter, tunAdapter device.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover, iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsAddresses []string, dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener, dnsReadyListener dns.ReadyListener,
) error { ) error {
// in case of non Android os these variables will be nil // in case of non Android os these variables will be nil

View File

@@ -16,7 +16,7 @@ const (
) )
type resolvConf struct { type resolvConf struct {
nameServers []string nameServers []netip.Addr
searchDomains []string searchDomains []string
others []string others []string
} }
@@ -36,7 +36,7 @@ func parseBackupResolvConf() (*resolvConf, error) {
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
rconf := &resolvConf{ rconf := &resolvConf{
searchDomains: make([]string, 0), searchDomains: make([]string, 0),
nameServers: make([]string, 0), nameServers: make([]netip.Addr, 0),
others: make([]string, 0), others: make([]string, 0),
} }
@@ -94,7 +94,11 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
if len(splitLines) != 2 { if len(splitLines) != 2 {
continue continue
} }
rconf.nameServers = append(rconf.nameServers, splitLines[1]) if addr, err := netip.ParseAddr(splitLines[1]); err == nil {
rconf.nameServers = append(rconf.nameServers, addr.Unmap())
} else {
log.Warnf("invalid nameserver address in resolv.conf: %s, skipping", splitLines[1])
}
continue continue
} }
@@ -104,31 +108,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
} }
return rconf, nil return rconf, nil
} }
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location
func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
resolvConf, err := parseResolvConfFile(filename)
if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err)
}
content, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("read %s: %w", filename, err)
}
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
stat, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("stat %s: %w", filename, err)
}
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
return fmt.Errorf("write %s: %w", filename, err)
}
}
return nil
}

View File

@@ -3,13 +3,9 @@
package dns package dns
import ( import (
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_parseResolvConf(t *testing.T) { func Test_parseResolvConf(t *testing.T) {
@@ -99,9 +95,13 @@ options debug
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains) t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
} }
ok = compareLists(cfg.nameServers, testCase.expectedNS) nsStrings := make([]string, len(cfg.nameServers))
for i, ns := range cfg.nameServers {
nsStrings[i] = ns.String()
}
ok = compareLists(nsStrings, testCase.expectedNS)
if !ok { if !ok {
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers) t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, nsStrings)
} }
ok = compareLists(cfg.others, testCase.expectedOther) ok = compareLists(cfg.others, testCase.expectedOther)
@@ -177,86 +177,3 @@ nameserver 192.168.0.1
} }
} }
func TestRemoveFirstNbNameserver(t *testing.T) {
testCases := []struct {
name string
content string
ipToRemove string
expected string
}{
{
name: "Unrelated nameservers with comments and options",
content: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
ipToRemove: "9.9.9.9",
expected: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
},
{
name: "First nameserver matches",
content: `search example.com
nameserver 9.9.9.9
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
ipToRemove: "9.9.9.9",
expected: `search example.com
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
},
{
name: "Target IP not the first nameserver",
// nolint:dupword
content: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
ipToRemove: "9.9.9.9",
// nolint:dupword
expected: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
},
{
name: "Only nameserver matches",
content: `options debug
nameserver 9.9.9.9
search localdomain`,
ipToRemove: "9.9.9.9",
expected: `options debug
nameserver 9.9.9.9
search localdomain`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "resolv.conf")
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
assert.NoError(t, err)
ip, err := netip.ParseAddr(tc.ipToRemove)
require.NoError(t, err, "Failed to parse IP address")
err = removeFirstNbNameserver(tempFile, ip)
assert.NoError(t, err)
content, err := os.ReadFile(tempFile)
assert.NoError(t, err)
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
})
}
}

View File

@@ -146,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rCon
return true return true
} }
if rConf.nameServers[0] != nbNameserverIP.String() { if rConf.nameServers[0] != nbNameserverIP {
return true return true
} }

View File

@@ -29,7 +29,7 @@ type fileConfigurator struct {
repair *repair repair *repair
originalPerms os.FileMode originalPerms os.FileMode
nbNameserverIP netip.Addr nbNameserverIP netip.Addr
originalNameservers []string originalNameservers []netip.Addr
} }
func newFileConfigurator() (*fileConfigurator, error) { func newFileConfigurator() (*fileConfigurator, error) {
@@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
} }
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf // getOriginalNameservers returns the nameservers that were found in the original resolv.conf
func (f *fileConfigurator) getOriginalNameservers() []string { func (f *fileConfigurator) getOriginalNameservers() []netip.Addr {
return f.originalNameservers return f.originalNameservers
} }
@@ -128,20 +128,14 @@ func (f *fileConfigurator) backup() error {
} }
func (f *fileConfigurator) restore() error { func (f *fileConfigurator) restore() error {
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP) if err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); err != nil {
if err != nil {
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
}
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
} }
return os.RemoveAll(fileDefaultResolvConfBackupLocation) return os.RemoveAll(fileDefaultResolvConfBackupLocation)
} }
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error {
resolvConf, err := parseDefaultResolvConf() resolvConf, err := parseDefaultResolvConf()
if err != nil { if err != nil {
return fmt.Errorf("parse current resolv.conf: %w", err) return fmt.Errorf("parse current resolv.conf: %w", err)
@@ -152,16 +146,9 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile() return restoreResolvConfFile()
} }
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
// not a valid first nameserver -> restore
if err != nil {
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
return restoreResolvConfFile()
}
// current address is still netbird's non-available dns address -> restore // current address is still netbird's non-available dns address -> restore
// comparing parsed addresses only, to remove ambiguity currentDNSAddress := resolvConf.nameServers[0]
if currentDNSAddress.String() == storedDNSAddress.String() { if currentDNSAddress == storedDNSAddress {
return restoreResolvConfFile() return restoreResolvConfFile()
} }

View File

@@ -239,7 +239,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} else if inServerAddressesArray { } else if inServerAddressesArray {
address := strings.Split(line, " : ")[1] address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address inServerAddressesArray = false // Stop reading after finding the first IPv4 address
} }
} }
@@ -250,7 +250,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} }
// default to 53 port // default to 53 port
dnsSettings.ServerPort = defaultPort dnsSettings.ServerPort = DefaultPort
return dnsSettings, nil return dnsSettings, nil
} }

View File

@@ -42,7 +42,7 @@ func (t osManagerType) String() string {
type restoreHostManager interface { type restoreHostManager interface {
hostManager hostManager
restoreUncleanShutdownDNS(*netip.Addr) error restoreUncleanShutdownDNS(netip.Addr) error
} }
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
@@ -130,8 +130,9 @@ func checkStub() bool {
return true return true
} }
systemdResolvedAddr := netip.AddrFrom4([4]byte{127, 0, 0, 53}) // 127.0.0.53
for _, ns := range rConf.nameServers { for _, ns := range rConf.nameServers {
if ns == "127.0.0.53" { if ns == systemdResolvedAddr {
return true return true
} }
} }

View File

@@ -216,7 +216,7 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
return fmt.Errorf("adding dns setup for all failed: %w", err) return fmt.Errorf("adding dns setup for all failed: %w", err)
} }
r.routingAll = true r.routingAll = true
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip) log.Infof("configured %s:%d as main DNS forwarder for this peer", ip, DefaultPort)
return nil return nil
} }

View File

@@ -1,38 +1,31 @@
package dns package dns
import ( import (
"fmt"
"net/netip" "net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus"
) )
type hostsDNSHolder struct { type hostsDNSHolder struct {
unprotectedDNSList map[string]struct{} unprotectedDNSList map[netip.AddrPort]struct{}
mutex sync.RWMutex mutex sync.RWMutex
} }
func newHostsDNSHolder() *hostsDNSHolder { func newHostsDNSHolder() *hostsDNSHolder {
return &hostsDNSHolder{ return &hostsDNSHolder{
unprotectedDNSList: make(map[string]struct{}), unprotectedDNSList: make(map[netip.AddrPort]struct{}),
} }
} }
func (h *hostsDNSHolder) set(list []string) { func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Lock() h.mutex.Lock()
h.unprotectedDNSList = make(map[string]struct{}) h.unprotectedDNSList = make(map[netip.AddrPort]struct{})
for _, dns := range list { for _, addrPort := range list {
dnsAddr, err := h.normalizeAddress(dns) h.unprotectedDNSList[addrPort] = struct{}{}
if err != nil {
continue
}
h.unprotectedDNSList[dnsAddr] = struct{}{}
} }
h.mutex.Unlock() h.mutex.Unlock()
} }
func (h *hostsDNSHolder) get() map[string]struct{} { func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock() h.mutex.RLock()
l := h.unprotectedDNSList l := h.unprotectedDNSList
h.mutex.RUnlock() h.mutex.RUnlock()
@@ -40,24 +33,10 @@ func (h *hostsDNSHolder) get() map[string]struct{} {
} }
//nolint:unused //nolint:unused
func (h *hostsDNSHolder) isContain(upstream string) bool { func (h *hostsDNSHolder) contains(upstream netip.AddrPort) bool {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
_, ok := h.unprotectedDNSList[upstream] _, ok := h.unprotectedDNSList[upstream]
return ok return ok
} }
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
a, err := netip.ParseAddr(addr)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
return "", err
}
if a.Is4() {
return fmt.Sprintf("%s:53", addr), nil
} else {
return fmt.Sprintf("[%s]:53", addr), nil
}
}

View File

@@ -50,7 +50,7 @@ func (m *MockServer) DnsIP() netip.Addr {
return netip.MustParseAddr("100.10.254.255") return netip.MustParseAddr("100.10.254.255")
} }
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) {
// TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }

View File

@@ -245,7 +245,7 @@ func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
return nil return nil
} }
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := n.restoreHostDNS(); err != nil { if err := n.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via network-manager: %w", err) return fmt.Errorf("restoring dns via network-manager: %w", err)
} }

View File

@@ -40,7 +40,7 @@ type resolvconf struct {
implType resolvconfType implType resolvconfType
originalSearchDomains []string originalSearchDomains []string
originalNameServers []string originalNameServers []netip.Addr
othersConfigs []string othersConfigs []string
} }
@@ -110,7 +110,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
return nil return nil
} }
func (r *resolvconf) getOriginalNameservers() []string { func (r *resolvconf) getOriginalNameservers() []netip.Addr {
return r.originalNameServers return r.originalNameServers
} }
@@ -158,7 +158,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
return nil return nil
} }
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error { func (r *resolvconf) restoreUncleanShutdownDNS(netip.Addr) error {
if err := r.restoreHostDNS(); err != nil { if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err) return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
} }

View File

@@ -42,7 +42,7 @@ type Server interface {
Stop() Stop()
DnsIP() netip.Addr DnsIP() netip.Addr
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string) OnUpdatedHostDNSServer(addrs []netip.AddrPort)
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
} }
@@ -55,7 +55,7 @@ type nsGroupsByDomain struct {
// hostManagerWithOriginalNS extends the basic hostManager interface // hostManagerWithOriginalNS extends the basic hostManager interface
type hostManagerWithOriginalNS interface { type hostManagerWithOriginalNS interface {
hostManager hostManager
getOriginalNameservers() []string getOriginalNameservers() []netip.Addr
} }
// DefaultServer dns server object // DefaultServer dns server object
@@ -136,7 +136,7 @@ func NewDefaultServer(
func NewDefaultServerPermanentUpstream( func NewDefaultServerPermanentUpstream(
ctx context.Context, ctx context.Context,
wgInterface WGIface, wgInterface WGIface,
hostsDnsList []string, hostsDnsList []netip.AddrPort,
config nbdns.Config, config nbdns.Config,
listener listener.NetworkChangeListener, listener listener.NetworkChangeListener,
statusRecorder *peer.Status, statusRecorder *peer.Status,
@@ -144,6 +144,7 @@ func NewDefaultServerPermanentUpstream(
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.hostsDNSHolder.set(hostsDnsList) ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.addHostRootZone() ds.addHostRootZone()
@@ -340,7 +341,7 @@ func (s *DefaultServer) disableDNS() error {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) {
s.hostsDNSHolder.set(hostsDnsList) s.hostsDNSHolder.set(hostsDnsList)
// Check if there's any root handler // Check if there's any root handler
@@ -461,7 +462,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != DefaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
@@ -581,14 +582,13 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
} }
for _, ns := range originalNameservers { for _, ns := range originalNameservers {
if ns == config.ServerIP.String() { if ns == config.ServerIP {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
continue continue
} }
ns = formatAddr(ns, defaultPort) addrPort := netip.AddrPortFrom(ns, DefaultPort)
handler.upstreamServers = append(handler.upstreamServers, addrPort)
handler.upstreamServers = append(handler.upstreamServers, ns)
} }
handler.deactivate = func(error) { /* always active */ } handler.deactivate = func(error) { /* always active */ }
handler.reactivate = func() { /* always active */ } handler.reactivate = func() { /* always active */ }
@@ -695,7 +695,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue continue
} }
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns)) handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
} }
if len(handler.upstreamServers) == 0 { if len(handler.upstreamServers) == 0 {
@@ -770,18 +770,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
s.dnsMuxMap = muxUpdateMap s.dnsMuxMap = muxUpdateMap
} }
func getNSHostPort(ns nbdns.NameServer) string {
return formatAddr(ns.IP.String(), ns.Port)
}
// formatAddr formats a nameserver address with port, handling IPv6 addresses properly
func formatAddr(address string, port int) string {
if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() {
return fmt.Sprintf("[%s]:%d", address, port)
}
return fmt.Sprintf("%s:%d", address, port)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate // upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to // the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate. // reactivate it. Not allowed to call reactivate before deactivate.
@@ -879,10 +867,7 @@ func (s *DefaultServer) addHostRootZone() {
return return
} }
handler.upstreamServers = make([]string, 0) handler.upstreamServers = maps.Keys(hostDNSServers)
for k := range hostDNSServers {
handler.upstreamServers = append(handler.upstreamServers, k)
}
handler.deactivate = func(error) {} handler.deactivate = func(error) {}
handler.reactivate = func() {} handler.reactivate = func() {}
@@ -893,9 +878,9 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
var states []peer.NSGroupState var states []peer.NSGroupState
for _, group := range groups { for _, group := range groups {
var servers []string var servers []netip.AddrPort
for _, ns := range group.NameServers { for _, ns := range group.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, ns.AddrPort())
} }
state := peer.NSGroupState{ state := peer.NSGroupState{
@@ -927,7 +912,7 @@ func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error,
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
var servers []string var servers []string
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, ns.AddrPort().String())
} }
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
} }

View File

@@ -97,9 +97,9 @@ func init() {
} }
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string var srvs []netip.AddrPort
for _, srv := range servers { for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv)) srvs = append(srvs, srv.AddrPort())
} }
return &upstreamResolverBase{ return &upstreamResolverBase{
domain: domain, domain: domain,
@@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
var dnsList []string var dnsList []netip.AddrPort
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false) dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
@@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
} }
defer dnsServer.Stop() defer dnsServer.Stop()
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"}) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort})
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort()) resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io") _, err = resolver.LookupHost(context.Background(), "netbird.io")
@@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
if err != nil { if err != nil {
t.Errorf("failed to initialize DNS server: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
@@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
if err != nil { if err != nil {
t.Errorf("failed to initialize DNS server: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
@@ -2053,56 +2056,3 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
} }
func TestFormatAddr(t *testing.T) {
tests := []struct {
name string
address string
port int
expected string
}{
{
name: "IPv4 address",
address: "8.8.8.8",
port: 53,
expected: "8.8.8.8:53",
},
{
name: "IPv4 address with custom port",
address: "1.1.1.1",
port: 5353,
expected: "1.1.1.1:5353",
},
{
name: "IPv6 address",
address: "fd78:94bf:7df8::1",
port: 53,
expected: "[fd78:94bf:7df8::1]:53",
},
{
name: "IPv6 address with custom port",
address: "2001:db8::1",
port: 5353,
expected: "[2001:db8::1]:5353",
},
{
name: "IPv6 localhost",
address: "::1",
port: 53,
expected: "[::1]:53",
},
{
name: "Invalid address treated as hostname",
address: "dns.example.com",
port: 53,
expected: "dns.example.com:53",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatAddr(tt.address, tt.port)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -7,7 +7,7 @@ import (
) )
const ( const (
defaultPort = 53 DefaultPort = 53
) )
type service interface { type service interface {

View File

@@ -122,7 +122,7 @@ func (s *serviceViaListener) RuntimePort() int {
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil { if s.ebpfService != nil {
return defaultPort return DefaultPort
} else { } else {
return int(s.listenPort) return int(s.listenPort)
} }
@@ -148,9 +148,9 @@ func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
return s.customAddr.Addr(), s.customAddr.Port(), nil return s.customAddr.Addr(), s.customAddr.Port(), nil
} }
ip, ok := s.testFreePort(defaultPort) ip, ok := s.testFreePort(DefaultPort)
if ok { if ok {
return ip, defaultPort, nil return ip, DefaultPort, nil
} }
ebpfSrv, port, ok := s.tryToUseeBPF() ebpfSrv, port, ok := s.tryToUseeBPF()

View File

@@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: lastIP, runtimeIP: lastIP,
runtimePort: defaultPort, runtimePort: DefaultPort,
} }
return s return s
} }

View File

@@ -235,7 +235,7 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
return nil return nil
} }
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via systemd: %w", err) return fmt.Errorf("restoring dns via systemd: %w", err)
} }

View File

@@ -27,7 +27,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create previous host manager: %w", err) return fmt.Errorf("create previous host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil { if err := manager.restoreUncleanShutdownDNS(s.DNSAddress); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }

View File

@@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
upstreamClient upstreamClient upstreamClient upstreamClient
upstreamServers []string upstreamServers []netip.AddrPort
domain string domain string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
@@ -79,17 +80,20 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
// String returns a string representation of the upstream resolver // String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string { func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %v", u.upstreamServers) return fmt.Sprintf("upstream %s", u.upstreamServers)
} }
// ID returns the unique handler ID // ID returns the unique handler ID
func (u *upstreamResolverBase) ID() types.HandlerID { func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers) servers := slices.Clone(u.upstreamServers)
slices.Sort(servers) slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New() hash := sha256.New()
hash.Write([]byte(u.domain + ":")) hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ","))) for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
} }
@@ -130,7 +134,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func() { func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel() defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r) rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}() }()
if err != nil { if err != nil {
@@ -197,7 +201,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
proto.SystemEvent_DNS, proto.SystemEvent_DNS,
"All upstream servers failed (fail count exceeded)", "All upstream servers failed (fail count exceeded)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, map[string]string{"upstreams": u.upstreamServersString()},
// TODO add domain meta // TODO add domain meta
) )
} }
@@ -258,7 +262,7 @@ func (u *upstreamResolverBase) ProbeAvailability() {
proto.SystemEvent_DNS, proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)", "All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, map[string]string{"upstreams": u.upstreamServersString()},
) )
} }
} }
@@ -278,7 +282,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
operation := func() error { operation := func() error {
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers)) return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default: default:
} }
@@ -291,7 +295,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
} }
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff()) log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error") return fmt.Errorf("upstream check call error")
} }
@@ -301,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return return
} }
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.failsCount.Store(0) u.failsCount.Store(0)
u.successCount.Add(1) u.successCount.Add(1)
u.reactivate() u.reactivate()
@@ -331,13 +335,21 @@ func (u *upstreamResolverBase) disable(err error) {
go u.waitUntilResponse() go u.waitUntilResponse()
} }
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error { func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout) ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel() defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(ctx, server, r) _, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
return err return err
} }

View File

@@ -79,8 +79,8 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
} }
func (u *upstreamResolver) isLocalResolver(upstream string) bool { func (u *upstreamResolver) isLocalResolver(upstream string) bool {
if u.hostsDNSHolder.isContain(upstream) { if addrPort, err := netip.ParseAddrPort(upstream); err == nil {
return true return u.hostsDNSHolder.contains(addrPort)
} }
return false return false
} }

View File

@@ -62,6 +62,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP, err := netip.ParseAddr(upstreamHost) upstreamIP, err := netip.ParseAddr(upstreamHost)
if err != nil { if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err) log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
} else {
upstreamIP = upstreamIP.Unmap()
} }
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() { if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)

View File

@@ -59,7 +59,14 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers // Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
if addrPort, err := netip.ParseAddrPort(server); err == nil {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {
cancel() cancel()
@@ -128,7 +135,8 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
} }
resolver.upstreamServers = []string{"0.0.0.0:-1"} addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
resolver.failsTillDeact = 0 resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100 resolver.reactivatePeriod = time.Microsecond * 100

View File

@@ -1,6 +1,8 @@
package internal package internal
import ( import (
"net/netip"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -13,7 +15,7 @@ type MobileDependency struct {
TunAdapter device.TunAdapter TunAdapter device.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
NetworkChangeListener listener.NetworkChangeListener NetworkChangeListener listener.NetworkChangeListener
HostDNSAddresses []string HostDNSAddresses []netip.AddrPort
DnsReadyListener dns.ReadyListener DnsReadyListener dns.ReadyListener
// iOS only // iOS only

View File

@@ -140,7 +140,7 @@ type RosenpassState struct {
// whether it's enabled, and the last error message encountered during probing. // whether it's enabled, and the last error message encountered during probing.
type NSGroupState struct { type NSGroupState struct {
ID string ID string
Servers []string Servers []netip.AddrPort
Domains []string Domains []string
Enabled bool Enabled bool
Error error Error error

View File

@@ -593,17 +593,9 @@ func update(input ConfigInput) (*Config, error) {
return config, nil return config, nil
} }
// GetConfig read config file and return with Config. Errors out if it does not exist
func GetConfig(configPath string) (*Config, error) { func GetConfig(configPath string) (*Config, error) {
if !fileExists(configPath) { return readConfig(configPath, false)
return nil, fmt.Errorf("config file %s does not exist", configPath)
}
config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", configPath, err)
}
return config, nil
} }
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
@@ -695,6 +687,11 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) { func ReadConfig(configPath string) (*Config, error) {
return readConfig(configPath, true)
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
if fileExists(configPath) { if fileExists(configPath) {
err := util.EnforcePermission(configPath) err := util.EnforcePermission(configPath)
if err != nil { if err != nil {
@@ -715,6 +712,8 @@ func ReadConfig(configPath string) (*Config, error) {
} }
return config, nil return config, nil
} else if !createIfMissing {
return nil, fmt.Errorf("config file %s does not exist", configPath)
} }
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath}) cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})

View File

@@ -16,19 +16,21 @@
<StandardDirectory Id="ProgramFiles64Folder"> <StandardDirectory Id="ProgramFiles64Folder">
<Directory Id="NetbirdInstallDir" Name="Netbird"> <Directory Id="NetbirdInstallDir" Name="Netbird">
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64"> <Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird.exe" KeyPath="yes" /> <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird-ui.exe"> <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" /> <Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" /> <Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
</File> </File>
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\wintun.dll" /> <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\opengl32.dll" /> <?if $(var.ArchSuffix) = "amd64" ?>
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
<?endif ?>
<ServiceInstall <ServiceInstall
Id="NetBirdService" Id="NetBirdService"
Name="NetBird" Name="NetBird"
DisplayName="NetBird" DisplayName="NetBird"
Description="A WireGuard-based mesh network that connects your devices into a single private network." Description="Connect your devices into a secure WireGuard-based overlay network with SSO, MFA and granular access controls."
Start="auto" Type="ownProcess" Start="auto" Type="ownProcess"
ErrorControl="normal" ErrorControl="normal"
Account="LocalSystem" Account="LocalSystem"

View File

@@ -1197,8 +1197,14 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
if dnsState.Error != nil { if dnsState.Error != nil {
err = dnsState.Error.Error() err = dnsState.Error.Error()
} }
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{ pbDnsState := &proto.NSGroupState{
Servers: dnsState.Servers, Servers: servers,
Domains: dnsState.Domains, Domains: dnsState.Domains,
Enabled: dnsState.Enabled, Enabled: dnsState.Enabled,
Error: err, Error: err,

View File

@@ -46,7 +46,7 @@ func (s *serviceClient) showProfilesUI() {
widget.NewLabel(""), // profile name widget.NewLabel(""), // profile name
layout.NewSpacer(), layout.NewSpacer(),
widget.NewButton("Select", nil), widget.NewButton("Select", nil),
widget.NewButton("Logout", nil), widget.NewButton("Deregister", nil),
widget.NewButton("Remove", nil), widget.NewButton("Remove", nil),
) )
}, },
@@ -128,7 +128,7 @@ func (s *serviceClient) showProfilesUI() {
} }
logoutBtn.Show() logoutBtn.Show()
logoutBtn.SetText("Logout") logoutBtn.SetText("Deregister")
logoutBtn.OnTapped = func() { logoutBtn.OnTapped = func() {
s.handleProfileLogout(profile.Name, refresh) s.handleProfileLogout(profile.Name, refresh)
} }
@@ -143,7 +143,7 @@ func (s *serviceClient) showProfilesUI() {
if !confirm { if !confirm {
return return
} }
err = s.removeProfile(profile.Name) err = s.removeProfile(profile.Name)
if err != nil { if err != nil {
log.Errorf("failed to remove profile: %v", err) log.Errorf("failed to remove profile: %v", err)
@@ -334,27 +334,27 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) { func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
dialog.ShowConfirm( dialog.ShowConfirm(
"Logout", "Deregister",
fmt.Sprintf("Are you sure you want to logout from '%s'?", profileName), fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
func(confirm bool) { func(confirm bool) {
if !confirm { if !confirm {
return return
} }
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("failed to get service client: %v", err) log.Errorf("failed to get service client: %v", err)
dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles) dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles)
return return
} }
currUser, err := user.Current() currUser, err := user.Current()
if err != nil { if err != nil {
log.Errorf("failed to get current user: %v", err) log.Errorf("failed to get current user: %v", err)
dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles) dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles)
return return
} }
username := currUser.Username username := currUser.Username
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{ _, err = conn.Logout(s.ctx, &proto.LogoutRequest{
ProfileName: &profileName, ProfileName: &profileName,
@@ -362,16 +362,16 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
}) })
if err != nil { if err != nil {
log.Errorf("logout failed: %v", err) log.Errorf("logout failed: %v", err)
dialog.ShowError(fmt.Errorf("logout failed"), s.wProfiles) dialog.ShowError(fmt.Errorf("deregister failed"), s.wProfiles)
return return
} }
dialog.ShowInformation( dialog.ShowInformation(
"Logged Out", "Deregistered",
fmt.Sprintf("Successfully logged out from '%s'", profileName), fmt.Sprintf("Successfully deregistered from '%s'", profileName),
s.wProfiles, s.wProfiles,
) )
refreshCallback() refreshCallback()
}, },
s.wProfiles, s.wProfiles,
@@ -602,7 +602,7 @@ func (p *profileMenu) refresh() {
// Add Logout menu item // Add Logout menu item
ctx2, cancel2 := context.WithCancel(context.Background()) ctx2, cancel2 := context.WithCancel(context.Background())
logoutItem := p.profileMenuItem.AddSubMenuItem("Logout", "") logoutItem := p.profileMenuItem.AddSubMenuItem("Deregister", "")
p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2} p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2}
go func() { go func() {
@@ -616,9 +616,9 @@ func (p *profileMenu) refresh() {
} }
if err := p.eventHandler.logout(p.ctx); err != nil { if err := p.eventHandler.logout(p.ctx); err != nil {
log.Errorf("logout failed: %v", err) log.Errorf("logout failed: %v", err)
p.app.SendNotification(fyne.NewNotification("Error", "Failed to logout")) p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
} else { } else {
p.app.SendNotification(fyne.NewNotification("Success", "Logged out successfully")) p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
} }
} }
} }

View File

@@ -102,6 +102,11 @@ func (n *NameServer) IsEqual(other *NameServer) bool {
other.Port == n.Port other.Port == n.Port
} }
// AddrPort returns the nameserver as a netip.AddrPort
func (n *NameServer) AddrPort() netip.AddrPort {
return netip.AddrPortFrom(n.IP, uint16(n.Port))
}
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53 // ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
func ParseNameServerURL(nsURL string) (NameServer, error) { func ParseNameServerURL(nsURL string) (NameServer, error) {
parsedURL, err := url.Parse(nsURL) parsedURL, err := url.Parse(nsURL)

View File

@@ -571,19 +571,19 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
accountId := xid.New().String() accountId := xid.New().String()
_, err := am.Store.GetAccount(ctx, accountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
statusErr, _ := status.FromError(err) if err != nil {
switch {
case err == nil:
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
continue
case statusErr.Type() == status.NotFound:
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
default:
return nil, err return nil, err
} }
if exists {
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
continue
}
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
} }
return nil, status.Errorf(status.Internal, "error while creating new account") return nil, status.Errorf(status.Internal, "error while creating new account")
@@ -1143,21 +1143,29 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount() defer unlockAccount()
newUser := types.NewRegularUser(userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, newUser)
if err != nil { if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID
err = am.Store.SaveUser(ctx, newUser)
if err != nil {
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
if err != nil {
return "", err
}
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
return domainAccountID, nil
}
return "", err return "", err
} }
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID) return user.AccountID, nil
if err != nil {
return "", err
}
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
return domainAccountID, nil
} }
// redeemInvite checks whether user has been invited and redeems the invite // redeemInvite checks whether user has been invited and redeems the invite

View File

@@ -3453,6 +3453,50 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}) })
} }
func TestDefaultAccountManager_AddNewUserToDomainAccount(t *testing.T) {
testCases := []struct {
name string
userAuth nbcontext.UserAuth
expectedRole types.UserRole
}{
{
name: "existing user",
userAuth: nbcontext.UserAuth{
Domain: "example.com",
UserId: "user1",
},
expectedRole: types.UserRoleOwner,
},
{
name: "new user",
userAuth: nbcontext.UserAuth{
Domain: "example.com",
UserId: "user2",
},
expectedRole: types.UserRoleUser,
},
}
manager, err := createManager(t)
require.NoError(t, err)
accountID, err := manager.GetAccountIDByUserID(context.Background(), "user1", "example.com")
require.NoError(t, err, "create init user failed")
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
userAccountID, err := manager.addNewUserToDomainAccount(context.Background(), accountID, tc.userAuth)
require.NoError(t, err)
assert.Equal(t, accountID, userAccountID)
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, tc.userAuth.UserId)
require.NoError(t, err)
assert.Equal(t, accountID, user.AccountID)
assert.Equal(t, tc.expectedRole, user.Role)
})
}
}
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -913,6 +913,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
start := time.Now()
empty := &proto.Empty{} empty := &proto.Empty{}
peerKey, err := s.parseRequest(ctx, req, empty) peerKey, err := s.parseRequest(ctx, req, empty)
@@ -944,7 +945,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID) s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String()) log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
return &proto.Empty{}, nil return &proto.Empty{}, nil
} }

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
func GetColumnName(db *gorm.DB, column string) string { func GetColumnName(db *gorm.DB, column string) string {
@@ -466,7 +467,7 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
} }
for _, value := range data { for _, value := range data {
if err := tx.Create( if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(
mapperFunc(row["account_id"].(string), row["id"].(string), value), mapperFunc(row["account_id"].(string), row["id"].(string), value),
).Error; err != nil { ).Error; err != nil {
return fmt.Errorf("failed to insert id %v: %w", row["id"], err) return fmt.Errorf("failed to insert id %v: %w", row["id"], err)

View File

@@ -609,7 +609,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
newPeer.DNSLabel = freeLabel newPeer.DNSLabel = freeLabel
newPeer.IP = freeIP newPeer.IP = freeIP
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer func() { defer func() {
if unlock != nil { if unlock != nil {
unlock() unlock()

View File

@@ -1476,8 +1476,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
func Test_RegisterPeerRollbackOnFailure(t *testing.T) { func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
engine := os.Getenv("NETBIRD_STORE_ENGINE") engine := os.Getenv("NETBIRD_STORE_ENGINE")
if engine == "sqlite" || engine == "" { if engine == "sqlite" || engine == "mysql" || engine == "" {
t.Skip("Skipping test because sqlite test store is not respecting foreign keys") // we intentionally disabled foreign keys in mysql
t.Skip("Skipping test because store is not respecting foreign keys")
} }
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")

View File

@@ -24,6 +24,7 @@ import (
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbcontext "github.com/netbirdio/netbird/management/server/context"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -76,7 +77,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
conns = runtime.NumCPU() conns = runtime.NumCPU()
} }
if storeEngine == types.SqliteStoreEngine { switch storeEngine {
case types.MysqlStoreEngine:
if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil {
return nil, err
}
case types.SqliteStoreEngine:
if err == nil { if err == nil {
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
} }
@@ -142,14 +148,16 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) { func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID) log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
start := time.Now() startWait := time.Now()
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex) mtx := value.(*sync.RWMutex)
mtx.Lock() mtx.Lock()
log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait))
startHold := time.Now()
unlock = func() { unlock = func() {
mtx.Unlock() mtx.Unlock()
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start)) log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold))
} }
return unlock return unlock
@@ -159,19 +167,22 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID) log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
start := time.Now() startWait := time.Now()
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex) mtx := value.(*sync.RWMutex)
mtx.RLock() mtx.RLock()
log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait))
startHold := time.Now()
unlock = func() { unlock = func() {
mtx.RUnlock() mtx.RUnlock()
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start)) log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold))
} }
return unlock return unlock
} }
// Deprecated: Full account operations are no longer supported
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error { func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
start := time.Now() start := time.Now()
defer func() { defer func() {
@@ -603,13 +614,16 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
} }
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var user types.User var user types.User
result := tx.Take(&user, idQueryCondition, userID) result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID) return nil, status.NewUserNotFoundError(userID)
@@ -1075,13 +1089,16 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
} }
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var accountNetwork types.AccountNetwork var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError(accountID)
} }
@@ -1091,13 +1108,16 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
} }
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var peer nbpeer.Peer var peer nbpeer.Peer
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey) result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -1146,8 +1166,11 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
// SaveUserLastLogin stores the last login time for a user in DB. // SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var user types.User var user types.User
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID) result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID) return status.NewUserNotFoundError(userID)
@@ -1328,13 +1351,16 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
} }
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var setupKey types.SetupKey var setupKey types.SetupKey
result := tx. result := tx.WithContext(ctx).
Take(&setupKey, GetKeyQueryCondition(s), key) Take(&setupKey, GetKeyQueryCondition(s), key)
if result.Error != nil { if result.Error != nil {
@@ -1348,7 +1374,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.Model(&types.SetupKey{}). ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
Where(idQueryCondition, setupKeyID). Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{ Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"), "used_times": gorm.Expr("used_times + 1"),
@@ -1368,8 +1397,11 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var groupID string var groupID string
_ = s.db.Model(types.Group{}). _ = s.db.WithContext(ctx).Model(types.Group{}).
Select("id"). Select("id").
Where("account_id = ? AND name = ?", accountID, "All"). Where("account_id = ? AND name = ?", accountID, "All").
Limit(1). Limit(1).
@@ -1397,13 +1429,16 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
// AddPeerToGroup adds a peer to a group // AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error { func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
peer := &types.GroupPeer{ peer := &types.GroupPeer{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
PeerID: peerID, PeerID: peerID,
} }
err := s.db.Clauses(clause.OnConflict{ err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true, DoNothing: true,
}).Create(peer).Error }).Create(peer).Error
@@ -1593,7 +1628,10 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
} }
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.Create(peer).Error; err != nil { ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err) return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
} }
@@ -1719,7 +1757,10 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
} }
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store") return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -2761,3 +2802,33 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
return groupPeers, nil return groupPeers, nil
} }
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
}
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
}
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
}
go func() {
select {
case <-ctx.Done():
case <-grpcCtx.Done():
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
}
}()
return ctx, cancel
}

View File

@@ -503,7 +503,7 @@ func (c *GrpcClient) Logout() error {
return fmt.Errorf("get server public key: %w", err) return fmt.Errorf("get server public key: %w", err)
} }
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*5) mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*15)
defer cancel() defer cancel()
message := &proto.Empty{} message := &proto.Empty{}