mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Compare commits
17 Commits
snyk-fix-c
...
handle-exi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b86463e96 | ||
|
|
9deff6f06b | ||
|
|
a942e4add5 | ||
|
|
1022a5015c | ||
|
|
375fcf2752 | ||
|
|
9acf7f9262 | ||
|
|
82937ba184 | ||
|
|
0f52144894 | ||
|
|
0926400b8a | ||
|
|
bef99d48f8 | ||
|
|
9e95841252 | ||
|
|
6da3943559 | ||
|
|
f5b4659adb | ||
|
|
3d19468b6c | ||
|
|
1a1e94c805 | ||
|
|
ed939bf7f5 | ||
|
|
7caf733217 |
12
.github/pull_request_template.md
vendored
12
.github/pull_request_template.md
vendored
@@ -12,6 +12,16 @@
|
||||
- [ ] Is a feature enhancement
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] Extended the README / documentation, if necessary
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
## Documentation
|
||||
Select exactly one:
|
||||
|
||||
- [ ] I added/updated documentation for this change
|
||||
- [ ] Documentation is **not needed** for this change (explain why)
|
||||
|
||||
### Docs PR URL (required if "docs added" is checked)
|
||||
Paste the PR link from https://github.com/netbirdio/docs here:
|
||||
|
||||
https://github.com/netbirdio/docs/pull/__
|
||||
|
||||
94
.github/workflows/docs-ack.yml
vendored
Normal file
94
.github/workflows/docs-ack.yml
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
name: Docs Acknowledgement
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
docs-ack:
|
||||
name: Require docs PR URL or explicit "not needed"
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Read PR body
|
||||
id: body
|
||||
run: |
|
||||
BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH")
|
||||
echo "body<<EOF" >> $GITHUB_OUTPUT
|
||||
echo "$BODY" >> $GITHUB_OUTPUT
|
||||
echo "EOF" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Validate checkbox selection
|
||||
id: validate
|
||||
run: |
|
||||
body='${{ steps.body.outputs.body }}'
|
||||
|
||||
added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ')
|
||||
noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ')
|
||||
|
||||
if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then
|
||||
echo "::error::Choose exactly one: either 'docs added' OR 'not needed'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "$added_checked" -eq 0 ] && [ "$noneed_checked" -eq 0 ]; then
|
||||
echo "::error::You must check exactly one docs option in the PR template."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "$added_checked" -eq 1 ]; then
|
||||
echo "mode=added" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "mode=noneed" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Extract docs PR URL (when 'docs added')
|
||||
if: steps.validate.outputs.mode == 'added'
|
||||
id: extract
|
||||
run: |
|
||||
body='${{ steps.body.outputs.body }}'
|
||||
|
||||
# Strictly require HTTPS and that it's a PR in netbirdio/docs
|
||||
# Examples accepted:
|
||||
# https://github.com/netbirdio/docs/pull/1234
|
||||
url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)
|
||||
|
||||
if [ -z "$url" ]; then
|
||||
echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')
|
||||
echo "url=$url" >> $GITHUB_OUTPUT
|
||||
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Verify docs PR exists (and is open or merged)
|
||||
if: steps.validate.outputs.mode == 'added'
|
||||
uses: actions/github-script@v7
|
||||
id: verify
|
||||
with:
|
||||
pr_number: ${{ steps.extract.outputs.pr_number }}
|
||||
script: |
|
||||
const prNumber = parseInt(core.getInput('pr_number'), 10);
|
||||
const { data } = await github.rest.pulls.get({
|
||||
owner: 'netbirdio',
|
||||
repo: 'docs',
|
||||
pull_number: prNumber
|
||||
});
|
||||
|
||||
// Allow open or merged PRs
|
||||
const ok = data.state === 'open' || data.merged === true;
|
||||
core.setOutput('state', data.state);
|
||||
core.setOutput('merged', String(!!data.merged));
|
||||
if (!ok) {
|
||||
core.setFailed(`Docs PR #${prNumber} exists but is neither open nor merged (state=${data.state}, merged=${data.merged}).`);
|
||||
}
|
||||
result-encoding: string
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: All good
|
||||
run: echo "Documentation requirement satisfied ✅"
|
||||
18
.github/workflows/forum.yml
vendored
Normal file
18
.github/workflows/forum.yml
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
name: Post release topic on Discourse
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
post:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: roots/discourse-topic-github-release-action@main
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
28
.github/workflows/release.yml
vendored
28
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.21"
|
||||
SIGN_PIPE_VER: "v0.0.22"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -79,6 +79,8 @@ jobs:
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
@@ -154,10 +156,20 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
|
||||
- name: Install LLVM-MinGW for ARM64 cross-compilation
|
||||
run: |
|
||||
cd /tmp
|
||||
wget -q https://github.com/mstorsjo/llvm-mingw/releases/download/20250709/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
|
||||
echo "60cafae6474c7411174cff1d4ba21a8e46cadbaeb05a1bace306add301628337 llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz" | sha256sum -c
|
||||
tar -xf llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64.tar.xz
|
||||
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
@@ -231,17 +243,3 @@ jobs:
|
||||
ref: ${{ env.SIGN_PIPE_VER }}
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||
|
||||
post_on_forum:
|
||||
runs-on: ubuntu-latest
|
||||
continue-on-error: true
|
||||
needs: [trigger_signer]
|
||||
steps:
|
||||
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
|
||||
@@ -16,8 +16,6 @@ builds:
|
||||
- arm64
|
||||
- 386
|
||||
ignore:
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
- goos: windows
|
||||
goarch: arm
|
||||
- goos: windows
|
||||
|
||||
@@ -15,7 +15,7 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-ui-windows
|
||||
- id: netbird-ui-windows-amd64
|
||||
dir: client/ui
|
||||
binary: netbird-ui
|
||||
env:
|
||||
@@ -30,6 +30,22 @@ builds:
|
||||
- -H windowsgui
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-ui-windows-arm64
|
||||
dir: client/ui
|
||||
binary: netbird-ui
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- CC=aarch64-w64-mingw32-clang
|
||||
- CXX=aarch64-w64-mingw32-clang++
|
||||
goos:
|
||||
- windows
|
||||
goarch:
|
||||
- arm64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -H windowsgui
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
archives:
|
||||
- id: linux-arch
|
||||
name_template: "{{ .ProjectName }}-linux_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
@@ -38,7 +54,8 @@ archives:
|
||||
- id: windows-arch
|
||||
name_template: "{{ .ProjectName }}-windows_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
builds:
|
||||
- netbird-ui-windows
|
||||
- netbird-ui-windows-amd64
|
||||
- netbird-ui-windows-arm64
|
||||
|
||||
nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
|
||||
@@ -4,6 +4,7 @@ package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -112,7 +113,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
@@ -138,7 +139,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -235,7 +236,7 @@ func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsServer.OnUpdatedHostDNSServer(list.items)
|
||||
dnsServer.OnUpdatedHostDNSServer(slices.Clone(list.items))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,23 +1,34 @@
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
// DNSList is a wrapper of []string
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
)
|
||||
|
||||
// DNSList is a wrapper of []netip.AddrPort with default DNS port
|
||||
type DNSList struct {
|
||||
items []string
|
||||
items []netip.AddrPort
|
||||
}
|
||||
|
||||
// Add new DNS address to the collection
|
||||
func (array *DNSList) Add(s string) {
|
||||
array.items = append(array.items, s)
|
||||
// Add new DNS address to the collection, returns error if invalid
|
||||
func (array *DNSList) Add(s string) error {
|
||||
addr, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid DNS address: %s", s)
|
||||
}
|
||||
addrPort := netip.AddrPortFrom(addr.Unmap(), dns.DefaultPort)
|
||||
array.items = append(array.items, addrPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get return an element of the collection
|
||||
// Get return an element of the collection as string
|
||||
func (array *DNSList) Get(i int) (string, error) {
|
||||
if i >= len(array.items) || i < 0 {
|
||||
return "", fmt.Errorf("out of range")
|
||||
}
|
||||
return array.items[i], nil
|
||||
return array.items[i].Addr().String(), nil
|
||||
}
|
||||
|
||||
// Size return with the size of the collection
|
||||
|
||||
@@ -3,20 +3,30 @@ package android
|
||||
import "testing"
|
||||
|
||||
func TestDNSList_Get(t *testing.T) {
|
||||
l := DNSList{
|
||||
items: make([]string, 1),
|
||||
l := DNSList{}
|
||||
|
||||
// Add a valid DNS address
|
||||
err := l.Add("8.8.8.8")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
_, err := l.Get(0)
|
||||
// Test getting valid index
|
||||
addr, err := l.Get(0)
|
||||
if err != nil {
|
||||
t.Errorf("invalid error: %s", err)
|
||||
}
|
||||
if addr != "8.8.8.8" {
|
||||
t.Errorf("expected 8.8.8.8, got %s", addr)
|
||||
}
|
||||
|
||||
// Test negative index
|
||||
_, err = l.Get(-1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
}
|
||||
|
||||
// Test out of bounds index
|
||||
_, err = l.Get(1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
|
||||
@@ -12,14 +12,15 @@ import (
|
||||
)
|
||||
|
||||
var logoutCmd = &cobra.Command{
|
||||
Use: "logout",
|
||||
Short: "logout from the NetBird Management Service and delete peer",
|
||||
Use: "deregister",
|
||||
Aliases: []string{"logout"},
|
||||
Short: "deregister from the NetBird Management Service and delete peer",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
||||
ctx, cancel := context.WithTimeout(cmd.Context(), time.Second*15)
|
||||
defer cancel()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
@@ -44,10 +45,10 @@ var logoutCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if _, err := daemonClient.Logout(ctx, req); err != nil {
|
||||
return fmt.Errorf("logout: %v", err)
|
||||
return fmt.Errorf("deregister: %v", err)
|
||||
}
|
||||
|
||||
cmd.Println("Logged out successfully")
|
||||
cmd.Println("Deregistered successfully")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
||||
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||
|
||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
|
||||
@@ -59,8 +59,8 @@ var sshCmd = &cobra.Command{
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
sm := profilemanager.NewServiceManager(configPath)
|
||||
activeProf, err := sm.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
!define WEB_SITE "Netbird.io"
|
||||
!define VERSION $%APPVER%
|
||||
!define COPYRIGHT "Netbird Authors, 2022"
|
||||
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network"
|
||||
!define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
|
||||
!define INSTALLER_NAME "netbird-installer.exe"
|
||||
!define MAIN_APP_EXE "Netbird"
|
||||
!define ICON "ui\\assets\\netbird.ico"
|
||||
@@ -59,9 +59,15 @@ ShowInstDetails Show
|
||||
!define MUI_UNICON "${ICON}"
|
||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||
!define MUI_UNWELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||
!define MUI_FINISHPAGE_RUN
|
||||
!define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}"
|
||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||
!ifndef ARCH
|
||||
!define ARCH "amd64"
|
||||
!endif
|
||||
|
||||
!if ${ARCH} == "amd64"
|
||||
!define MUI_FINISHPAGE_RUN
|
||||
!define MUI_FINISHPAGE_RUN_TEXT "Start ${UI_APP_NAME}"
|
||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||
!endif
|
||||
######################################################################
|
||||
|
||||
!define MUI_ABORTWARNING
|
||||
@@ -213,7 +219,15 @@ Section -MainProgram
|
||||
${INSTALL_TYPE}
|
||||
# SetOverwrite ifnewer
|
||||
SetOutPath "$INSTDIR"
|
||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||
!ifndef ARCH
|
||||
!define ARCH "amd64"
|
||||
!endif
|
||||
|
||||
!if ${ARCH} == "arm64"
|
||||
File /r "..\\dist\\netbird_windows_arm64\\"
|
||||
!else
|
||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||
!endif
|
||||
SectionEnd
|
||||
######################################################################
|
||||
|
||||
@@ -292,7 +306,9 @@ DetailPrint "Deleting application files..."
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
!if ${ARCH} == "amd64"
|
||||
Delete "$INSTDIR\opengl32.dll"
|
||||
!endif
|
||||
DetailPrint "Removing application directory..."
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
@@ -314,8 +330,10 @@ DetailPrint "Uninstallation finished."
|
||||
SectionEnd
|
||||
|
||||
|
||||
!if ${ARCH} == "amd64"
|
||||
Function LaunchLink
|
||||
SetShellVarContext all
|
||||
SetOutPath $INSTDIR
|
||||
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
|
||||
FunctionEnd
|
||||
!endif
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
@@ -70,7 +71,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
tunAdapter device.TunAdapter,
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []string,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
)
|
||||
|
||||
type resolvConf struct {
|
||||
nameServers []string
|
||||
nameServers []netip.Addr
|
||||
searchDomains []string
|
||||
others []string
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func parseBackupResolvConf() (*resolvConf, error) {
|
||||
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
||||
rconf := &resolvConf{
|
||||
searchDomains: make([]string, 0),
|
||||
nameServers: make([]string, 0),
|
||||
nameServers: make([]netip.Addr, 0),
|
||||
others: make([]string, 0),
|
||||
}
|
||||
|
||||
@@ -94,7 +94,11 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
||||
if len(splitLines) != 2 {
|
||||
continue
|
||||
}
|
||||
rconf.nameServers = append(rconf.nameServers, splitLines[1])
|
||||
if addr, err := netip.ParseAddr(splitLines[1]); err == nil {
|
||||
rconf.nameServers = append(rconf.nameServers, addr.Unmap())
|
||||
} else {
|
||||
log.Warnf("invalid nameserver address in resolv.conf: %s, skipping", splitLines[1])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -104,31 +108,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
||||
}
|
||||
return rconf, nil
|
||||
}
|
||||
|
||||
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
|
||||
// and writes the file back to the original location
|
||||
func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
|
||||
resolvConf, err := parseResolvConfFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse backup resolv.conf: %w", err)
|
||||
}
|
||||
content, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", filename, err)
|
||||
}
|
||||
|
||||
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
|
||||
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
|
||||
|
||||
stat, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat %s: %w", filename, err)
|
||||
}
|
||||
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
|
||||
return fmt.Errorf("write %s: %w", filename, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,13 +3,9 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_parseResolvConf(t *testing.T) {
|
||||
@@ -99,9 +95,13 @@ options debug
|
||||
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
|
||||
}
|
||||
|
||||
ok = compareLists(cfg.nameServers, testCase.expectedNS)
|
||||
nsStrings := make([]string, len(cfg.nameServers))
|
||||
for i, ns := range cfg.nameServers {
|
||||
nsStrings[i] = ns.String()
|
||||
}
|
||||
ok = compareLists(nsStrings, testCase.expectedNS)
|
||||
if !ok {
|
||||
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers)
|
||||
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, nsStrings)
|
||||
}
|
||||
|
||||
ok = compareLists(cfg.others, testCase.expectedOther)
|
||||
@@ -177,86 +177,3 @@ nameserver 192.168.0.1
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFirstNbNameserver(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
content string
|
||||
ipToRemove string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Unrelated nameservers with comments and options",
|
||||
content: `# This is a comment
|
||||
options rotate
|
||||
nameserver 1.1.1.1
|
||||
# Another comment
|
||||
nameserver 8.8.4.4
|
||||
search example.com`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `# This is a comment
|
||||
options rotate
|
||||
nameserver 1.1.1.1
|
||||
# Another comment
|
||||
nameserver 8.8.4.4
|
||||
search example.com`,
|
||||
},
|
||||
{
|
||||
name: "First nameserver matches",
|
||||
content: `search example.com
|
||||
nameserver 9.9.9.9
|
||||
# oof, a comment
|
||||
nameserver 8.8.4.4
|
||||
options attempts:5`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `search example.com
|
||||
# oof, a comment
|
||||
nameserver 8.8.4.4
|
||||
options attempts:5`,
|
||||
},
|
||||
{
|
||||
name: "Target IP not the first nameserver",
|
||||
// nolint:dupword
|
||||
content: `# Comment about the first nameserver
|
||||
nameserver 8.8.4.4
|
||||
# Comment before our target
|
||||
nameserver 9.9.9.9
|
||||
options timeout:2`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
// nolint:dupword
|
||||
expected: `# Comment about the first nameserver
|
||||
nameserver 8.8.4.4
|
||||
# Comment before our target
|
||||
nameserver 9.9.9.9
|
||||
options timeout:2`,
|
||||
},
|
||||
{
|
||||
name: "Only nameserver matches",
|
||||
content: `options debug
|
||||
nameserver 9.9.9.9
|
||||
search localdomain`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `options debug
|
||||
nameserver 9.9.9.9
|
||||
search localdomain`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "resolv.conf")
|
||||
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ip, err := netip.ParseAddr(tc.ipToRemove)
|
||||
require.NoError(t, err, "Failed to parse IP address")
|
||||
err = removeFirstNbNameserver(tempFile, ip)
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(tempFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rCon
|
||||
return true
|
||||
}
|
||||
|
||||
if rConf.nameServers[0] != nbNameserverIP.String() {
|
||||
if rConf.nameServers[0] != nbNameserverIP {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ type fileConfigurator struct {
|
||||
repair *repair
|
||||
originalPerms os.FileMode
|
||||
nbNameserverIP netip.Addr
|
||||
originalNameservers []string
|
||||
originalNameservers []netip.Addr
|
||||
}
|
||||
|
||||
func newFileConfigurator() (*fileConfigurator, error) {
|
||||
@@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
|
||||
}
|
||||
|
||||
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf
|
||||
func (f *fileConfigurator) getOriginalNameservers() []string {
|
||||
func (f *fileConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return f.originalNameservers
|
||||
}
|
||||
|
||||
@@ -128,20 +128,14 @@ func (f *fileConfigurator) backup() error {
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) restore() error {
|
||||
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
||||
if err != nil {
|
||||
if err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); err != nil {
|
||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
||||
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error {
|
||||
resolvConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse current resolv.conf: %w", err)
|
||||
@@ -152,16 +146,9 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
|
||||
// not a valid first nameserver -> restore
|
||||
if err != nil {
|
||||
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
// current address is still netbird's non-available dns address -> restore
|
||||
// comparing parsed addresses only, to remove ambiguity
|
||||
if currentDNSAddress.String() == storedDNSAddress.String() {
|
||||
currentDNSAddress := resolvConf.nameServers[0]
|
||||
if currentDNSAddress == storedDNSAddress {
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
|
||||
@@ -239,7 +239,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
} else if inServerAddressesArray {
|
||||
address := strings.Split(line, " : ")[1]
|
||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip
|
||||
dnsSettings.ServerIP = ip.Unmap()
|
||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
||||
}
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
}
|
||||
|
||||
// default to 53 port
|
||||
dnsSettings.ServerPort = defaultPort
|
||||
dnsSettings.ServerPort = DefaultPort
|
||||
|
||||
return dnsSettings, nil
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func (t osManagerType) String() string {
|
||||
|
||||
type restoreHostManager interface {
|
||||
hostManager
|
||||
restoreUncleanShutdownDNS(*netip.Addr) error
|
||||
restoreUncleanShutdownDNS(netip.Addr) error
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
@@ -130,8 +130,9 @@ func checkStub() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
systemdResolvedAddr := netip.AddrFrom4([4]byte{127, 0, 0, 53}) // 127.0.0.53
|
||||
for _, ns := range rConf.nameServers {
|
||||
if ns == "127.0.0.53" {
|
||||
if ns == systemdResolvedAddr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,7 +216,7 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||
}
|
||||
r.routingAll = true
|
||||
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", ip, DefaultPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,38 +1,31 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type hostsDNSHolder struct {
|
||||
unprotectedDNSList map[string]struct{}
|
||||
unprotectedDNSList map[netip.AddrPort]struct{}
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func newHostsDNSHolder() *hostsDNSHolder {
|
||||
return &hostsDNSHolder{
|
||||
unprotectedDNSList: make(map[string]struct{}),
|
||||
unprotectedDNSList: make(map[netip.AddrPort]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hostsDNSHolder) set(list []string) {
|
||||
func (h *hostsDNSHolder) set(list []netip.AddrPort) {
|
||||
h.mutex.Lock()
|
||||
h.unprotectedDNSList = make(map[string]struct{})
|
||||
for _, dns := range list {
|
||||
dnsAddr, err := h.normalizeAddress(dns)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
h.unprotectedDNSList[dnsAddr] = struct{}{}
|
||||
h.unprotectedDNSList = make(map[netip.AddrPort]struct{})
|
||||
for _, addrPort := range list {
|
||||
h.unprotectedDNSList[addrPort] = struct{}{}
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *hostsDNSHolder) get() map[string]struct{} {
|
||||
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
|
||||
h.mutex.RLock()
|
||||
l := h.unprotectedDNSList
|
||||
h.mutex.RUnlock()
|
||||
@@ -40,24 +33,10 @@ func (h *hostsDNSHolder) get() map[string]struct{} {
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func (h *hostsDNSHolder) isContain(upstream string) bool {
|
||||
func (h *hostsDNSHolder) contains(upstream netip.AddrPort) bool {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
_, ok := h.unprotectedDNSList[upstream]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
|
||||
a, err := netip.ParseAddr(addr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if a.Is4() {
|
||||
return fmt.Sprintf("%s:53", addr), nil
|
||||
} else {
|
||||
return fmt.Sprintf("[%s]:53", addr), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func (m *MockServer) DnsIP() netip.Addr {
|
||||
return netip.MustParseAddr("100.10.254.255")
|
||||
}
|
||||
|
||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
||||
func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
@@ -245,7 +245,7 @@ func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
|
||||
if err := n.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via network-manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ type resolvconf struct {
|
||||
implType resolvconfType
|
||||
|
||||
originalSearchDomains []string
|
||||
originalNameServers []string
|
||||
originalNameServers []netip.Addr
|
||||
othersConfigs []string
|
||||
}
|
||||
|
||||
@@ -110,7 +110,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvconf) getOriginalNameservers() []string {
|
||||
func (r *resolvconf) getOriginalNameservers() []netip.Addr {
|
||||
return r.originalNameServers
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
func (r *resolvconf) restoreUncleanShutdownDNS(netip.Addr) error {
|
||||
if err := r.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ type Server interface {
|
||||
Stop()
|
||||
DnsIP() netip.Addr
|
||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
OnUpdatedHostDNSServer(addrs []netip.AddrPort)
|
||||
SearchDomains() []string
|
||||
ProbeAvailability()
|
||||
}
|
||||
@@ -55,7 +55,7 @@ type nsGroupsByDomain struct {
|
||||
// hostManagerWithOriginalNS extends the basic hostManager interface
|
||||
type hostManagerWithOriginalNS interface {
|
||||
hostManager
|
||||
getOriginalNameservers() []string
|
||||
getOriginalNameservers() []netip.Addr
|
||||
}
|
||||
|
||||
// DefaultServer dns server object
|
||||
@@ -136,7 +136,7 @@ func NewDefaultServer(
|
||||
func NewDefaultServerPermanentUpstream(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
hostsDnsList []string,
|
||||
hostsDnsList []netip.AddrPort,
|
||||
config nbdns.Config,
|
||||
listener listener.NetworkChangeListener,
|
||||
statusRecorder *peer.Status,
|
||||
@@ -144,6 +144,7 @@ func NewDefaultServerPermanentUpstream(
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
@@ -340,7 +341,7 @@ func (s *DefaultServer) disableDNS() error {
|
||||
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) {
|
||||
s.hostsDNSHolder.set(hostsDnsList)
|
||||
|
||||
// Check if there's any root handler
|
||||
@@ -461,7 +462,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
if s.service.RuntimePort() != DefaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||
s.currentConfig.RouteAll = false
|
||||
@@ -581,14 +582,13 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
}
|
||||
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == config.ServerIP.String() {
|
||||
if ns == config.ServerIP {
|
||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
|
||||
continue
|
||||
}
|
||||
|
||||
ns = formatAddr(ns, defaultPort)
|
||||
|
||||
handler.upstreamServers = append(handler.upstreamServers, ns)
|
||||
addrPort := netip.AddrPortFrom(ns, DefaultPort)
|
||||
handler.upstreamServers = append(handler.upstreamServers, addrPort)
|
||||
}
|
||||
handler.deactivate = func(error) { /* always active */ }
|
||||
handler.reactivate = func() { /* always active */ }
|
||||
@@ -695,7 +695,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||
handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
|
||||
}
|
||||
|
||||
if len(handler.upstreamServers) == 0 {
|
||||
@@ -770,18 +770,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return formatAddr(ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
// formatAddr formats a nameserver address with port, handling IPv6 addresses properly
|
||||
func formatAddr(address string, port int) string {
|
||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() {
|
||||
return fmt.Sprintf("[%s]:%d", address, port)
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", address, port)
|
||||
}
|
||||
|
||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||
// the upstream resolver from the configuration, the second one is used to
|
||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||
@@ -879,10 +867,7 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
return
|
||||
}
|
||||
|
||||
handler.upstreamServers = make([]string, 0)
|
||||
for k := range hostDNSServers {
|
||||
handler.upstreamServers = append(handler.upstreamServers, k)
|
||||
}
|
||||
handler.upstreamServers = maps.Keys(hostDNSServers)
|
||||
handler.deactivate = func(error) {}
|
||||
handler.reactivate = func() {}
|
||||
|
||||
@@ -893,9 +878,9 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||
var states []peer.NSGroupState
|
||||
|
||||
for _, group := range groups {
|
||||
var servers []string
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range group.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
servers = append(servers, ns.AddrPort())
|
||||
}
|
||||
|
||||
state := peer.NSGroupState{
|
||||
@@ -927,7 +912,7 @@ func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error,
|
||||
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||
var servers []string
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
servers = append(servers, ns.AddrPort().String())
|
||||
}
|
||||
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
||||
}
|
||||
|
||||
@@ -97,9 +97,9 @@ func init() {
|
||||
}
|
||||
|
||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []string
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, getNSHostPort(srv))
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
return &upstreamResolverBase{
|
||||
domain: domain,
|
||||
@@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
|
||||
var dnsList []string
|
||||
var dnsList []netip.AddrPort
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
err = dnsServer.Initialize()
|
||||
@@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
||||
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort})
|
||||
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
@@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -2053,56 +2056,3 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||
}
|
||||
|
||||
func TestFormatAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
address string
|
||||
port int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 address",
|
||||
address: "8.8.8.8",
|
||||
port: 53,
|
||||
expected: "8.8.8.8:53",
|
||||
},
|
||||
{
|
||||
name: "IPv4 address with custom port",
|
||||
address: "1.1.1.1",
|
||||
port: 5353,
|
||||
expected: "1.1.1.1:5353",
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
address: "fd78:94bf:7df8::1",
|
||||
port: 53,
|
||||
expected: "[fd78:94bf:7df8::1]:53",
|
||||
},
|
||||
{
|
||||
name: "IPv6 address with custom port",
|
||||
address: "2001:db8::1",
|
||||
port: 5353,
|
||||
expected: "[2001:db8::1]:5353",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost",
|
||||
address: "::1",
|
||||
port: 53,
|
||||
expected: "[::1]:53",
|
||||
},
|
||||
{
|
||||
name: "Invalid address treated as hostname",
|
||||
address: "dns.example.com",
|
||||
port: 53,
|
||||
expected: "dns.example.com:53",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := formatAddr(tt.address, tt.port)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
DefaultPort = 53
|
||||
)
|
||||
|
||||
type service interface {
|
||||
|
||||
@@ -122,7 +122,7 @@ func (s *serviceViaListener) RuntimePort() int {
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if s.ebpfService != nil {
|
||||
return defaultPort
|
||||
return DefaultPort
|
||||
} else {
|
||||
return int(s.listenPort)
|
||||
}
|
||||
@@ -148,9 +148,9 @@ func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
|
||||
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
||||
}
|
||||
|
||||
ip, ok := s.testFreePort(defaultPort)
|
||||
ip, ok := s.testFreePort(DefaultPort)
|
||||
if ok {
|
||||
return ip, defaultPort, nil
|
||||
return ip, DefaultPort, nil
|
||||
}
|
||||
|
||||
ebpfSrv, port, ok := s.tryToUseeBPF()
|
||||
|
||||
@@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: lastIP,
|
||||
runtimePort: defaultPort,
|
||||
runtimePort: DefaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -235,7 +235,7 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
|
||||
if err := s.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via systemd: %w", err)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func (s *ShutdownState) Cleanup() error {
|
||||
return fmt.Errorf("create previous host manager: %w", err)
|
||||
}
|
||||
|
||||
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
|
||||
if err := manager.restoreUncleanShutdownDNS(s.DNSAddress); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []string
|
||||
upstreamServers []netip.AddrPort
|
||||
domain string
|
||||
disabled bool
|
||||
failsCount atomic.Int32
|
||||
@@ -79,17 +80,20 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
||||
|
||||
// String returns a string representation of the upstream resolver
|
||||
func (u *upstreamResolverBase) String() string {
|
||||
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
||||
return fmt.Sprintf("upstream %s", u.upstreamServers)
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||
servers := slices.Clone(u.upstreamServers)
|
||||
slices.Sort(servers)
|
||||
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
|
||||
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(u.domain + ":"))
|
||||
hash.Write([]byte(strings.Join(servers, ",")))
|
||||
for _, s := range servers {
|
||||
hash.Write([]byte(s.String()))
|
||||
hash.Write([]byte("|"))
|
||||
}
|
||||
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||
}
|
||||
|
||||
@@ -130,7 +134,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||
defer cancel()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
@@ -197,7 +201,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
||||
proto.SystemEvent_DNS,
|
||||
"All upstream servers failed (fail count exceeded)",
|
||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
|
||||
map[string]string{"upstreams": u.upstreamServersString()},
|
||||
// TODO add domain meta
|
||||
)
|
||||
}
|
||||
@@ -258,7 +262,7 @@ func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
proto.SystemEvent_DNS,
|
||||
"All upstream servers failed (probe failed)",
|
||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
|
||||
map[string]string{"upstreams": u.upstreamServersString()},
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -278,7 +282,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
operation := func() error {
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers))
|
||||
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
|
||||
default:
|
||||
}
|
||||
|
||||
@@ -291,7 +295,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
}
|
||||
}
|
||||
|
||||
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff())
|
||||
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
|
||||
return fmt.Errorf("upstream check call error")
|
||||
}
|
||||
|
||||
@@ -301,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||
u.failsCount.Store(0)
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
@@ -331,13 +335,21 @@ func (u *upstreamResolverBase) disable(err error) {
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
|
||||
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
var servers []string
|
||||
for _, server := range u.upstreamServers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
return strings.Join(servers, ", ")
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||
|
||||
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
||||
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -79,8 +79,8 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
if u.hostsDNSHolder.isContain(upstream) {
|
||||
return true
|
||||
if addrPort, err := netip.ParseAddrPort(upstream); err == nil {
|
||||
return u.hostsDNSHolder.contains(addrPort)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -62,6 +62,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
||||
} else {
|
||||
upstreamIP = upstreamIP.Unmap()
|
||||
}
|
||||
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||
log.Debugf("using private client to query upstream: %s", upstream)
|
||||
|
||||
@@ -59,7 +59,14 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||
resolver.upstreamServers = testCase.InputServers
|
||||
// Convert test servers to netip.AddrPort
|
||||
var servers []netip.AddrPort
|
||||
for _, server := range testCase.InputServers {
|
||||
if addrPort, err := netip.ParseAddrPort(server); err == nil {
|
||||
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
|
||||
}
|
||||
}
|
||||
resolver.upstreamServers = servers
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
cancel()
|
||||
@@ -128,7 +135,8 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
}
|
||||
resolver.upstreamServers = []string{"0.0.0.0:-1"}
|
||||
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
|
||||
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
|
||||
resolver.failsTillDeact = 0
|
||||
resolver.reactivatePeriod = time.Microsecond * 100
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -13,7 +15,7 @@ type MobileDependency struct {
|
||||
TunAdapter device.TunAdapter
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
NetworkChangeListener listener.NetworkChangeListener
|
||||
HostDNSAddresses []string
|
||||
HostDNSAddresses []netip.AddrPort
|
||||
DnsReadyListener dns.ReadyListener
|
||||
|
||||
// iOS only
|
||||
|
||||
@@ -140,7 +140,7 @@ type RosenpassState struct {
|
||||
// whether it's enabled, and the last error message encountered during probing.
|
||||
type NSGroupState struct {
|
||||
ID string
|
||||
Servers []string
|
||||
Servers []netip.AddrPort
|
||||
Domains []string
|
||||
Enabled bool
|
||||
Error error
|
||||
|
||||
@@ -593,17 +593,9 @@ func update(input ConfigInput) (*Config, error) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetConfig read config file and return with Config. Errors out if it does not exist
|
||||
func GetConfig(configPath string) (*Config, error) {
|
||||
if !fileExists(configPath) {
|
||||
return nil, fmt.Errorf("config file %s does not exist", configPath)
|
||||
}
|
||||
|
||||
config := &Config{}
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
return readConfig(configPath, false)
|
||||
}
|
||||
|
||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
||||
@@ -695,6 +687,11 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func ReadConfig(configPath string) (*Config, error) {
|
||||
return readConfig(configPath, true)
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||
if fileExists(configPath) {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
@@ -715,6 +712,8 @@ func ReadConfig(configPath string) (*Config, error) {
|
||||
}
|
||||
|
||||
return config, nil
|
||||
} else if !createIfMissing {
|
||||
return nil, fmt.Errorf("config file %s does not exist", configPath)
|
||||
}
|
||||
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||
|
||||
@@ -16,19 +16,21 @@
|
||||
<StandardDirectory Id="ProgramFiles64Folder">
|
||||
<Directory Id="NetbirdInstallDir" Name="Netbird">
|
||||
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
|
||||
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird.exe" KeyPath="yes" />
|
||||
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird-ui.exe">
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
|
||||
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||
</File>
|
||||
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\wintun.dll" />
|
||||
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\opengl32.dll" />
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
||||
<?if $(var.ArchSuffix) = "amd64" ?>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
||||
<?endif ?>
|
||||
|
||||
<ServiceInstall
|
||||
Id="NetBirdService"
|
||||
Name="NetBird"
|
||||
DisplayName="NetBird"
|
||||
Description="A WireGuard-based mesh network that connects your devices into a single private network."
|
||||
Description="Connect your devices into a secure WireGuard-based overlay network with SSO, MFA and granular access controls."
|
||||
Start="auto" Type="ownProcess"
|
||||
ErrorControl="normal"
|
||||
Account="LocalSystem"
|
||||
|
||||
@@ -1197,8 +1197,14 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
if dnsState.Error != nil {
|
||||
err = dnsState.Error.Error()
|
||||
}
|
||||
|
||||
var servers []string
|
||||
for _, server := range dnsState.Servers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
|
||||
pbDnsState := &proto.NSGroupState{
|
||||
Servers: dnsState.Servers,
|
||||
Servers: servers,
|
||||
Domains: dnsState.Domains,
|
||||
Enabled: dnsState.Enabled,
|
||||
Error: err,
|
||||
|
||||
@@ -46,7 +46,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
widget.NewLabel(""), // profile name
|
||||
layout.NewSpacer(),
|
||||
widget.NewButton("Select", nil),
|
||||
widget.NewButton("Logout", nil),
|
||||
widget.NewButton("Deregister", nil),
|
||||
widget.NewButton("Remove", nil),
|
||||
)
|
||||
},
|
||||
@@ -128,7 +128,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
}
|
||||
|
||||
logoutBtn.Show()
|
||||
logoutBtn.SetText("Logout")
|
||||
logoutBtn.SetText("Deregister")
|
||||
logoutBtn.OnTapped = func() {
|
||||
s.handleProfileLogout(profile.Name, refresh)
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
if !confirm {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
err = s.removeProfile(profile.Name)
|
||||
if err != nil {
|
||||
log.Errorf("failed to remove profile: %v", err)
|
||||
@@ -334,27 +334,27 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
|
||||
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
|
||||
dialog.ShowConfirm(
|
||||
"Logout",
|
||||
fmt.Sprintf("Are you sure you want to logout from '%s'?", profileName),
|
||||
"Deregister",
|
||||
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
|
||||
func(confirm bool) {
|
||||
if !confirm {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get service client: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get current user: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
username := currUser.Username
|
||||
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
|
||||
ProfileName: &profileName,
|
||||
@@ -362,16 +362,16 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("logout failed: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("logout failed"), s.wProfiles)
|
||||
dialog.ShowError(fmt.Errorf("deregister failed"), s.wProfiles)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
dialog.ShowInformation(
|
||||
"Logged Out",
|
||||
fmt.Sprintf("Successfully logged out from '%s'", profileName),
|
||||
"Deregistered",
|
||||
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
|
||||
s.wProfiles,
|
||||
)
|
||||
|
||||
|
||||
refreshCallback()
|
||||
},
|
||||
s.wProfiles,
|
||||
@@ -602,7 +602,7 @@ func (p *profileMenu) refresh() {
|
||||
|
||||
// Add Logout menu item
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
logoutItem := p.profileMenuItem.AddSubMenuItem("Logout", "")
|
||||
logoutItem := p.profileMenuItem.AddSubMenuItem("Deregister", "")
|
||||
p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2}
|
||||
|
||||
go func() {
|
||||
@@ -616,9 +616,9 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
if err := p.eventHandler.logout(p.ctx); err != nil {
|
||||
log.Errorf("logout failed: %v", err)
|
||||
p.app.SendNotification(fyne.NewNotification("Error", "Failed to logout"))
|
||||
p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
|
||||
} else {
|
||||
p.app.SendNotification(fyne.NewNotification("Success", "Logged out successfully"))
|
||||
p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +102,11 @@ func (n *NameServer) IsEqual(other *NameServer) bool {
|
||||
other.Port == n.Port
|
||||
}
|
||||
|
||||
// AddrPort returns the nameserver as a netip.AddrPort
|
||||
func (n *NameServer) AddrPort() netip.AddrPort {
|
||||
return netip.AddrPortFrom(n.IP, uint16(n.Port))
|
||||
}
|
||||
|
||||
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
|
||||
func ParseNameServerURL(nsURL string) (NameServer, error) {
|
||||
parsedURL, err := url.Parse(nsURL)
|
||||
|
||||
@@ -571,19 +571,19 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
|
||||
for i := 0; i < 2; i++ {
|
||||
accountId := xid.New().String()
|
||||
|
||||
_, err := am.Store.GetAccount(ctx, accountId)
|
||||
statusErr, _ := status.FromError(err)
|
||||
switch {
|
||||
case err == nil:
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
case statusErr.Type() == status.NotFound:
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
|
||||
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
default:
|
||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if exists {
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
}
|
||||
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
|
||||
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.Internal, "error while creating new account")
|
||||
@@ -1143,21 +1143,29 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
|
||||
newUser := types.NewRegularUser(userAuth.UserId)
|
||||
newUser.AccountID = domainAccountID
|
||||
err := am.Store.SaveUser(ctx, newUser)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
newUser := types.NewRegularUser(userAuth.UserId)
|
||||
newUser.AccountID = domainAccountID
|
||||
err = am.Store.SaveUser(ctx, newUser)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
return domainAccountID, nil
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
|
||||
return domainAccountID, nil
|
||||
return user.AccountID, nil
|
||||
}
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
|
||||
@@ -3453,6 +3453,50 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_AddNewUserToDomainAccount(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
userAuth nbcontext.UserAuth
|
||||
expectedRole types.UserRole
|
||||
}{
|
||||
{
|
||||
name: "existing user",
|
||||
userAuth: nbcontext.UserAuth{
|
||||
Domain: "example.com",
|
||||
UserId: "user1",
|
||||
},
|
||||
expectedRole: types.UserRoleOwner,
|
||||
},
|
||||
{
|
||||
name: "new user",
|
||||
userAuth: nbcontext.UserAuth{
|
||||
Domain: "example.com",
|
||||
UserId: "user2",
|
||||
},
|
||||
expectedRole: types.UserRoleUser,
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), "user1", "example.com")
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
userAccountID, err := manager.addNewUserToDomainAccount(context.Background(), accountID, tc.userAuth)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountID, userAccountID)
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, tc.userAuth.UserId)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountID, user.AccountID)
|
||||
assert.Equal(t, tc.expectedRole, user.Role)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -913,6 +913,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
|
||||
|
||||
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
||||
start := time.Now()
|
||||
|
||||
empty := &proto.Empty{}
|
||||
peerKey, err := s.parseRequest(ctx, req, empty)
|
||||
@@ -944,7 +945,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
|
||||
|
||||
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
|
||||
|
||||
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String())
|
||||
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func GetColumnName(db *gorm.DB, column string) string {
|
||||
@@ -466,7 +467,7 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
|
||||
}
|
||||
|
||||
for _, value := range data {
|
||||
if err := tx.Create(
|
||||
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(
|
||||
mapperFunc(row["account_id"].(string), row["id"].(string), value),
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
|
||||
|
||||
@@ -609,7 +609,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
newPeer.DNSLabel = freeLabel
|
||||
newPeer.IP = freeIP
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
|
||||
@@ -1476,8 +1476,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
|
||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||
if engine == "sqlite" || engine == "" {
|
||||
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
|
||||
if engine == "sqlite" || engine == "mysql" || engine == "" {
|
||||
// we intentionally disabled foreign keys in mysql
|
||||
t.Skip("Skipping test because store is not respecting foreign keys")
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -76,7 +77,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
switch storeEngine {
|
||||
case types.MysqlStoreEngine:
|
||||
if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case types.SqliteStoreEngine:
|
||||
if err == nil {
|
||||
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
|
||||
}
|
||||
@@ -142,14 +148,16 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
startWait := time.Now()
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||
startHold := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||
}
|
||||
|
||||
return unlock
|
||||
@@ -159,19 +167,22 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
|
||||
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
startWait := time.Now()
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.RLock()
|
||||
log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||
startHold := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.RUnlock()
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// Deprecated: Full account operations are no longer supported
|
||||
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -603,13 +614,16 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var user types.User
|
||||
result := tx.Take(&user, idQueryCondition, userID)
|
||||
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
@@ -1075,13 +1089,16 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var accountNetwork types.AccountNetwork
|
||||
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
|
||||
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
@@ -1091,13 +1108,16 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
|
||||
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -1146,8 +1166,11 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
var user types.User
|
||||
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewUserNotFoundError(userID)
|
||||
@@ -1328,13 +1351,16 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var setupKey types.SetupKey
|
||||
result := tx.
|
||||
result := tx.WithContext(ctx).
|
||||
Take(&setupKey, GetKeyQueryCondition(s), key)
|
||||
|
||||
if result.Error != nil {
|
||||
@@ -1348,7 +1374,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
result := s.db.Model(&types.SetupKey{}).
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
|
||||
Where(idQueryCondition, setupKeyID).
|
||||
Updates(map[string]interface{}{
|
||||
"used_times": gorm.Expr("used_times + 1"),
|
||||
@@ -1368,8 +1397,11 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
|
||||
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
var groupID string
|
||||
_ = s.db.Model(types.Group{}).
|
||||
_ = s.db.WithContext(ctx).Model(types.Group{}).
|
||||
Select("id").
|
||||
Where("account_id = ? AND name = ?", accountID, "All").
|
||||
Limit(1).
|
||||
@@ -1397,13 +1429,16 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
|
||||
// AddPeerToGroup adds a peer to a group
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
peer := &types.GroupPeer{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
PeerID: peerID,
|
||||
}
|
||||
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(peer).Error
|
||||
@@ -1593,7 +1628,10 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
if err := s.db.Create(peer).Error; err != nil {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||
}
|
||||
|
||||
@@ -1719,7 +1757,10 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||
@@ -2761,3 +2802,33 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
|
||||
|
||||
return groupPeers, nil
|
||||
}
|
||||
|
||||
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
|
||||
if ok {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
|
||||
}
|
||||
|
||||
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
|
||||
if ok {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
|
||||
}
|
||||
|
||||
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
|
||||
if ok {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-grpcCtx.Done():
|
||||
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
|
||||
}
|
||||
}()
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
@@ -503,7 +503,7 @@ func (c *GrpcClient) Logout() error {
|
||||
return fmt.Errorf("get server public key: %w", err)
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*5)
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*15)
|
||||
defer cancel()
|
||||
|
||||
message := &proto.Empty{}
|
||||
|
||||
Reference in New Issue
Block a user