Compare commits

...

47 Commits

Author SHA1 Message Date
hakansa
95794f53ce [client] fix Windows NRPT Policy Path (#4572)
[client] fix Windows NRPT Policy Path
2025-10-02 17:42:25 +07:00
hakansa
9bcd3ebed4 [management,client] Make DNS ForwarderPort Configurable & Change Well Known Port (#4479)
makes the DNS forwarder port configurable in the management and client components, while changing the well-known port from 5454 to 22054. The change includes version-aware port assignment to ensure backward compatibility.

- Adds a configurable `ForwarderPort` field to the DNS configuration protocol
- Implements version-based port computation that returns the new port (22054) only when all peers support version 0.59.0 or newer
- Updates the client to dynamically restart the DNS forwarder when the port changes
2025-10-02 01:02:10 +02:00
Maycon Santos
b85045e723 [misc] Update infra scripts with ws proxy for browser client (#4566)
* Update infra scripts with ws proxy for browser client

* add ws proxy to nginx tmpl
2025-10-02 00:52:54 +02:00
Viktor Liu
4d7e59f199 [client,signal,management] Adjust browser client ws proxy paths (#4565) 2025-10-02 00:10:47 +02:00
Viktor Liu
b5daec3b51 [client,signal,management] Add browser client support (#4415) 2025-10-01 20:10:11 +02:00
Zoltan Papp
5e1a40c33f [client] Order the list of candidates for proper comparison (#4561)
Order the list of candidates for proper comparison
2025-09-30 23:40:46 +02:00
Zoltan Papp
e8d301fdc9 [client] Fix/pkg loss (#3338)
The Relayed connection setup is optimistic. It does not have any confirmation of an established end-to-end connection. Peers start sending WireGuard handshake packets immediately after the successful offer-answer handshake.
Meanwhile, for successful P2P connection negotiation, we change the WireGuard endpoint address, but this change does not trigger new handshake initiation. Because the peer switched from Relayed connection to P2P, the packets from the Relay server are dropped and must wait for the next WireGuard handshake via P2P.

To avoid this scenario, the relayed WireGuard proxy no longer drops the packets. Instead, it rewrites the source address to the new P2P endpoint and continues forwarding the packets.

We still have one corner case: if the Relayed server negotiation chooses a server that has not been used before. In this case, one side of the peer connection will be slower to reach the Relay server, and the Relay server will drop the handshake packet.

If everything goes well we should see exactly 5 seconds improvements between the WireGuard configuration time and the handshake time.
2025-09-30 15:31:18 +02:00
hakansa
17bab881f7 [client] Add Windows DNS Policies To GPO Path Always (#4460)
[client] Add Windows DNS Policies To GPO Path Always (#4460)
2025-09-26 16:42:18 +07:00
Vlad
25ed58328a [management] fix network map dns filter (#4547) 2025-09-25 16:29:14 +02:00
hakansa
644ed4b934 [client] Add WireGuard interface lifecycle monitoring (#4370)
* [client] Add WireGuard interface lifecycle monitoring
2025-09-25 15:36:26 +07:00
Pascal Fischer
58faa341d2 [management] Add logs for update channel (#4527) 2025-09-23 12:06:10 +02:00
Viktor Liu
5853b5553c [client] Skip interface for route lookup if it doesn't exist (#4524) 2025-09-22 14:32:00 +02:00
Zoltan Papp
998fb30e1e [client] Check the client status in the earlier phase (#4509)
This PR improves the NetBird client's status checking mechanism by implementing earlier detection of client state changes and better handling of connection lifecycle management. The key improvements focus on:

  • Enhanced status detection - Added waitForReady option to StatusRequest for improved client status handling
  • Better connection management - Improved context handling for signal and management gRPC connections• Reduced connection timeouts - Increased gRPC dial timeout from 3 to 10 seconds for better reliability
  • Cleaner error handling - Enhanced error propagation and context cancellation in retry loops

  Key Changes

  Core Status Improvements:
  - Added waitForReady optional field to StatusRequest proto (daemon.proto:190)
  - Enhanced status checking logic to detect client state changes earlier in the connection process
  - Improved handling of client permanent exit scenarios from retry loops

  Connection & Context Management:
  - Fixed context cancellation in management and signal client retry mechanisms
  - Added proper context propagation for Login operations
  - Enhanced gRPC connection handling with better timeout management

  Error Handling & Cleanup:
  - Moved feedback channels to upper layers for better separation of concerns
  - Improved error handling patterns throughout the client server implementation
  - Fixed synchronization issues and removed debug logging
2025-09-20 22:14:01 +02:00
Maycon Santos
e254b4cde5 [misc] Update SIGN_PIPE_VER to version 0.0.23 (#4521) 2025-09-20 10:24:04 +02:00
Zoltan Papp
ead1c618ba [client] Do not run up cmd if not needed in docker (#4508)
optimizes the NetBird client startup process by avoiding unnecessary login commands when the peer is already authenticated. The changes increase the default login timeout and expand the log message patterns used to detect successful authentication.

- Increased default login timeout from 1 to 5 seconds for more reliable authentication detection
- Enhanced log pattern matching to detect both registration and ready states
- Added extended regex support for more flexible pattern matching
2025-09-20 10:00:18 +02:00
Viktor Liu
55126f990c [client] Use native windows sock opts to avoid routing loops (#4314)
- Move `util/grpc` and `util/net` to `client` so `internal` packages can be accessed
 - Add methods to return the next best interface after the NetBird interface.
- Use `IP_UNICAST_IF` sock opt to force the outgoing interface for the NetBird `net.Dialer` and `net.ListenerConfig` to avoid routing loops. The interface is picked by the new route lookup method.
- Some refactoring to avoid import cycles
- Old behavior is available through `NB_USE_LEGACY_ROUTING=true` env var
2025-09-20 09:31:04 +02:00
Misha Bragin
90577682e4 Add a new product demo video (#4520) 2025-09-19 13:06:44 +02:00
Bethuel Mmbaga
dc30dcacce [management] Filter DNS records to include only peers to connect (#4517)
DNS record filtering to only include peers that a peer can connect to, reducing unnecessary DNS data in the peer's network map.

- Adds a new `filterZoneRecordsForPeers` function to filter DNS records based on peer connectivity
- Modifies `GetPeerNetworkMap` to use filtered DNS records instead of all records in the custom zone
- Includes comprehensive test coverage for the new filtering functionality
2025-09-18 18:57:07 +02:00
Diego Romar
2c87fa6236 [android] Add OnLoginSuccess callback to URLOpener interface (#4492)
The callback will be fired once login -> internal.Login
completes without errors
2025-09-18 15:07:42 +02:00
hakansa
ec8d83ade4 [client] [UI] Down & Up NetBird Async When Settings Updated
[client] [UI] Down & Up NetBird Async When Settings Updated
2025-09-18 18:13:29 +07:00
Bethuel Mmbaga
3130cce72d [management] Add rule ID validation for policy updates (#4499) 2025-09-15 21:08:16 +03:00
Zoltan Papp
bd23ab925e [client] Fix ICE latency handling (#4501)
The GetSelectedCandidatePair() does not carry the latency information.
2025-09-15 15:08:53 +02:00
Zoltan Papp
0c6f671a7c Refactor healthcheck sender and receiver to use configurable options (#4433) 2025-09-12 09:31:03 +02:00
Bethuel Mmbaga
cf7f6c355f [misc] Remove default zitadel admin user in deployment script (#4482)
* Delete default zitadel-admin user during initialization

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-09-11 21:20:10 +02:00
Zoltan Papp
47e64d72db [client] Fix client status check (#4474)
The client status is not enough to protect the RPC calls from concurrency issues, because it is handled internally in the client in an asynchronous way.
2025-09-11 16:21:09 +02:00
Zoltan Papp
9e81e782e5 [client] Fix/v4 stun routing (#4430)
Deduplicate STUN package sending.
Originally, because every peer shared the same UDP address, the library could not distinguish which STUN message was associated with which candidate. As a result, the Pion library responded from all candidates for every STUN message.
2025-09-11 10:08:54 +02:00
Zoltan Papp
7aef0f67df [client] Implement environment variable handling for Android (#4440)
Some features can only be manipulated via environment variables. With this PR, environment variables can be managed from Android.
2025-09-08 18:42:42 +02:00
Maycon Santos
dba7ef667d [misc] Remove aur support and start service on ostree (#4461)
* Remove aur support and start service on ostree

The aur installation was adding many packages and installing more than just the client. For now is best to remove it and rely on binary install

Some users complained about ostree installation not starting the client, we add two explicit commands to it

* use  ${SUDO}

* fix if closure
2025-09-08 15:03:56 +02:00
Zoltan Papp
69d87343d2 [client] Debug information for connection (#4439)
Improve logging

Print the exact time when the first WireGuard handshake occurs
Print the steps for gathering system information
2025-09-08 14:51:34 +02:00
Bethuel Mmbaga
5113c70943 [management] Extends integration and peers manager (#4450) 2025-09-06 13:13:49 +03:00
Zoltan Papp
ad8fcda67b [client] Move some sys info to static place (#4446)
This PR refactors the system information collection code by moving static system information gathering to a dedicated location and separating platform-specific implementations. The primary goal is to improve code organization and maintainability by centralizing static info collection logic.

Key changes:
- Centralized static info collection into dedicated files with platform-specific implementations
- Moved `StaticInfo` struct definition to the main static_info.go file
- Added async initialization function `UpdateStaticInfoAsync()` across all platforms
2025-09-06 10:49:28 +02:00
Pascal Fischer
d33f88df82 [management] only allow user devices to be expired (#4445) 2025-09-05 18:11:23 +02:00
Zoltan Papp
786ca6fc79 Do not block Offer processing from relay worker (#4435)
- do not miss ICE offers when relay worker busy
- close p2p connection before recreate agent
2025-09-05 11:02:29 +02:00
Diego Romar
dfebdf1444 [internal] Add missing assignment of iFaceDiscover when netstack is disabled (#4444)
The internal updateInterfaces() function expects iFaceDiscover to not
be nil
2025-09-04 23:00:10 +02:00
Bethuel Mmbaga
a8dcff69c2 [management] Add peers manager to integrations (#4405) 2025-09-04 23:07:03 +03:00
Viktor Liu
71e944fa57 [relay] Let relay accept any origin (#4426) 2025-09-01 19:51:06 +02:00
Maycon Santos
d39fcfd62a [management] Add user approval (#4411)
This PR adds user approval functionality to the management system, allowing administrators to manually approve new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin.

Adds UserApprovalRequired setting to control manual user approval requirement
Introduces user approval and rejection endpoints with corresponding business logic
Prevents pending approval users from adding peers or logging in
2025-09-01 18:00:45 +02:00
Zoltan Papp
21368b38d9 [client] Update Pion ICE to the latest version (#4388)
- Update Pion version
- Update protobuf version
2025-09-01 10:42:01 +02:00
Maycon Santos
d817584f52 [misc] fix Windows client and management bench tests (#4424)
Windows tests had too many directories, causing issues to the payload via psexec.

Also migrated all checked benchmarks to send data to grafana.
2025-08-31 17:19:56 +02:00
Pascal Fischer
4d3dc3475d [management] remove duplicated removal of groups on peer delete (#4421) 2025-08-30 12:47:13 +02:00
Pascal Fischer
6fc50a438f [management] remove withContext from store methods (#4422) 2025-08-30 12:46:54 +02:00
Vlad
149559a06b [management] login filter to fix multiple peers connected with the same pub key (#3986) 2025-08-29 19:48:40 +02:00
Pascal Fischer
e14c6de203 [management] fix ephemeral flag on peer batch response (#4420) 2025-08-29 17:41:20 +02:00
Viktor Liu
d4c067f0af [client] Don't deactivate upstream resolvers on failure (#4128) 2025-08-29 17:40:05 +02:00
Pascal Fischer
dbefa8bd9f [management] remove lock and continue user update on failure (#4410) 2025-08-28 17:50:12 +02:00
Pascal Fischer
4fd10b9447 [management] split high latency grpc metrics (#4408) 2025-08-28 13:25:40 +02:00
Viktor Liu
aa595c3073 [client] Fix shared sock buffer allocation (#4409) 2025-08-28 13:25:16 +02:00
304 changed files with 11395 additions and 2963 deletions

View File

@@ -217,7 +217,7 @@ jobs:
- arch: "386" - arch: "386"
raceFlag: "" raceFlag: ""
- arch: "amd64" - arch: "amd64"
raceFlag: "" raceFlag: "-race"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -382,6 +382,32 @@ jobs:
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
@@ -428,9 +454,10 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \ CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \ go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... -timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
api_benchmark: api_benchmark:
name: "Management / Benchmark (API)" name: "Management / Benchmark (API)"
@@ -521,7 +548,7 @@ jobs:
-run=^$ \ -run=^$ \
-bench=. \ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... -timeout 20m ./management/server/http/...
api_integration_test: api_integration_test:
name: "Management / Integration" name: "Management / Integration"
@@ -571,4 +598,4 @@ jobs:
CI=true \ CI=true \
go test -tags=integration \ go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... ./shared/management/... -timeout 20m ./management/server/http/...

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
skip: go.mod,go.sum skip: go.mod,go.sum
golangci: golangci:
strategy: strategy:

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.22" SIGN_PIPE_VER: "v0.0.23"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"

View File

@@ -0,0 +1,67 @@
name: Wasm
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
with:
version: latest
install-mode: binary
skip-cache: true
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
continue-on-error: true
js_build:
name: "JS / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env:
CGO_ENABLED: 0
- name: Check Wasm build size
run: |
echo "Wasm build size:"
ls -lh netbird.wasm
SIZE=$(stat -c%s netbird.wasm)
SIZE_MB=$((SIZE / 1024 / 1024))
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
exit 1
fi

0
.gitmodules vendored Normal file
View File

View File

@@ -2,6 +2,18 @@ version: 2
project_name: netbird project_name: netbird
builds: builds:
- id: netbird-wasm
dir: client/wasm/cmd
binary: netbird
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
goos:
- js
goarch:
- wasm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird - id: netbird
dir: client dir: client
binary: netbird binary: netbird
@@ -115,6 +127,11 @@ archives:
- builds: - builds:
- netbird - netbird
- netbird-static - netbird-static
- id: netbird-wasm
builds:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>

View File

@@ -1,3 +1,4 @@
<div align="center"> <div align="center">
<br/> <br/>
<br/> <br/>
@@ -52,7 +53,7 @@
### Open Source Network Security in a Single Platform ### Open Source Network Security in a Single Platform
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" /> https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video) ### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -18,7 +18,7 @@ ENV \
NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1" NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -4,6 +4,7 @@ package android
import ( import (
"context" "context"
"os"
"slices" "slices"
"sync" "sync"
@@ -18,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/client/net"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps. // In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
func (c *Client) RemoveConnectionListener() { func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener() c.recorder.RemoveConnectionListener()
} }
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
}
}
}

View File

@@ -0,0 +1,32 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
)
// EnvList wraps a Go map for export to Java
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}

View File

@@ -33,6 +33,7 @@ type ErrListener interface {
// the backend want to show an url for the user // the backend want to show an url for the user
type URLOpener interface { type URLOpener interface {
Open(string) Open(string)
OnLoginSuccess()
} }
// Auth can register or login new client // Auth can register or login new client
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
err = a.withBackOff(a.ctx, func() error { err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken) err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil return nil
} }

8
client/cmd/debug_js.go Normal file
View File

@@ -0,0 +1,8 @@
package cmd
import "context"
// SetupDebugHandler is a no-op for WASM
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
// Debug handler not needed for WASM
}

View File

@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
} }
// update host's static platform and system information // update host's static platform and system information
system.UpdateStaticInfo() system.UpdateStaticInfoAsync()
configFilePath, err := activeProf.FilePath() configFilePath, err := activeProf.FilePath()
if err != nil { if err != nil {

View File

@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server. // DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*3) ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() defer cancel()
return grpc.DialContext( return grpc.DialContext(

View File

@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
log.Info("starting NetBird service") //nolint log.Info("starting NetBird service") //nolint
// Collect static system and platform information // Collect static system and platform information
system.UpdateStaticInfo() system.UpdateStaticInfoAsync()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer() p.serv = grpc.NewServer()

View File

@@ -9,29 +9,28 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmt "github.com/netbirdio/netbird/management/server" mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto" sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server" sig "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
) )
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
@@ -90,15 +89,20 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT(). settingsMockManager.EXPECT().
@@ -112,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{}) status, err := client.Status(ctx, &proto.StatusRequest{
WaitForReady: func() *bool { b := true; return &b }(),
})
if err != nil { if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err) return fmt.Errorf("unable to get daemon status: %v", err)
} }

View File

@@ -23,23 +23,29 @@ import (
var ErrClientAlreadyStarted = errors.New("client already started") var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started") var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
// Client manages a netbird embedded client instance // Client manages a netbird embedded client instance.
type Client struct { type Client struct {
deviceName string deviceName string
config *profilemanager.Config config *profilemanager.Config
mu sync.Mutex mu sync.Mutex
cancel context.CancelFunc cancel context.CancelFunc
setupKey string setupKey string
jwtToken string
connect *internal.ConnectClient connect *internal.ConnectClient
} }
// Options configures a new Client // Options configures a new Client.
type Options struct { type Options struct {
// DeviceName is this peer's name in the network // DeviceName is this peer's name in the network
DeviceName string DeviceName string
// SetupKey is used for authentication // SetupKey is used for authentication
SetupKey string SetupKey string
// JWTToken is used for JWT-based authentication
JWTToken string
// PrivateKey is used for direct private key authentication
PrivateKey string
// ManagementURL overrides the default management server URL // ManagementURL overrides the default management server URL
ManagementURL string ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface // PreSharedKey is the pre-shared key for the WireGuard interface
@@ -58,8 +64,35 @@ type Options struct {
DisableClientRoutes bool DisableClientRoutes bool
} }
// New creates a new netbird embedded client // validateCredentials checks that exactly one credential type is provided
func (opts *Options) validateCredentials() error {
credentialsProvided := 0
if opts.SetupKey != "" {
credentialsProvided++
}
if opts.JWTToken != "" {
credentialsProvided++
}
if opts.PrivateKey != "" {
credentialsProvided++
}
if credentialsProvided == 0 {
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
}
if credentialsProvided > 1 {
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
}
return nil
}
// New creates a new netbird embedded client.
func New(opts Options) (*Client, error) { func New(opts Options) (*Client, error) {
if err := opts.validateCredentials(); err != nil {
return nil, err
}
if opts.LogOutput != nil { if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput) logrus.SetOutput(opts.LogOutput)
} }
@@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) {
return nil, fmt.Errorf("create config: %w", err) return nil, fmt.Errorf("create config: %w", err)
} }
if opts.PrivateKey != "" {
config.PrivateKey = opts.PrivateKey
}
return &Client{ return &Client{
deviceName: opts.DeviceName, deviceName: opts.DeviceName,
setupKey: opts.SetupKey, setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config, config: config,
}, nil }, nil
} }
@@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background()) ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil { if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err) return fmt.Errorf("login: %w", err)
} }
@@ -135,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
run := make(chan struct{}, 1) run := make(chan struct{})
clientErr := make(chan error, 1) clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run); err != nil { if err := client.Run(run); err != nil {
@@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error {
} }
} }
// GetConfig returns a copy of the internal client config.
func (c *Client) GetConfig() (profilemanager.Config, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.config == nil {
return profilemanager.Config{}, ErrConfigNotInitialized
}
return *c.config, nil
}
// Dial dials a network address in the netbird network. // Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address) return nsnet.DialContext(ctx, network, address)
} }
// ListenTCP listens on the given address in the netbird network // ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) { func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet() nsnet, addr, err := c.getNet()
@@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
return nsnet.ListenTCP(tcpAddr) return nsnet.ListenTCP(tcpAddr)
} }
// ListenUDP listens on the given address in the netbird network // ListenUDP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) { func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet() nsnet, addr, err := c.getNet()

View File

@@ -12,7 +12,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const ( const (

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules

View File

@@ -14,7 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {

View File

@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const ( const (

View File

@@ -22,7 +22,7 @@ import (
nbid "github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const ( const (

View File

@@ -4,15 +4,9 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt"
"net"
"os/user"
"runtime" "runtime"
"time" "time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -21,35 +15,9 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/util/embeddedroots" "github.com/netbirdio/netbird/util/embeddedroots"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func WithCustomDialer() grpc.DialOption { // Backoff returns a backoff configuration for gRPC calls
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}
// grpcDialBackoff is the backoff mechanism for the grpc calls
func Backoff(ctx context.Context) backoff.BackOff { func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff() b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second b.MaxElapsedTime = 10 * time.Second
@@ -57,7 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx) return backoff.WithContext(b, ctx)
} }
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { // CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled { if tlsEnabled {
certPool, err := x509.SystemCertPool() certPool, err := x509.SystemCertPool()
@@ -67,18 +37,20 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
} }
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool, RootCAs: certPool,
})) }))
} }
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() defer cancel()
conn, err := grpc.DialContext( conn, err := grpc.DialContext(
connCtx, connCtx,
addr, addr,
transportOption, transportOption,
WithCustomDialer(), WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(), grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,

View File

@@ -0,0 +1,44 @@
//go:build !js
package grpc
import (
"context"
"fmt"
"net"
"os/user"
"runtime"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}

13
client/grpc/dialer_js.go Normal file
View File

@@ -0,0 +1,13 @@
package grpc
import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/util/wsproxy/client"
)
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled, component)
}

View File

@@ -3,7 +3,7 @@ package bind
import ( import (
wireguard "golang.zx2c4.com/wireguard/conn" wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) // TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)

View File

@@ -1,5 +1,17 @@
package bind package bind
import wgConn "golang.zx2c4.com/wireguard/conn" import (
"net"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type Endpoint = wgConn.StdNetEndpoint type Endpoint = wgConn.StdNetEndpoint
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
return &net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}
}

View File

@@ -0,0 +1,7 @@
package bind
import "fmt"
var (
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
)

View File

@@ -1,6 +1,9 @@
//go:build !js
package bind package bind
import ( import (
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
@@ -8,22 +11,18 @@ import (
"runtime" "runtime"
"sync" "sync"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct { type receiverCreator struct {
iceBind *ICEBind iceBind *ICEBind
} }
@@ -41,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
// use the port because in the Send function the wgConn.Endpoint the port info is not exported. // use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct { type ICEBind struct {
*wgConn.StdNetBind *wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net transportNet transport.Net
filterFn FilterFn filterFn udpmux.FilterFn
address wgaddr.Address
mtu uint16
endpoints map[netip.Addr]net.Conn endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex endpointsMu sync.Mutex
recvChan chan recvMessage
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one // new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{} closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool closed bool
activityRecorder *ActivityRecorder
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
address wgaddr.Address
mtu uint16
activityRecorder *ActivityRecorder
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet, transportNet: transportNet,
filterFn: filterFn, filterFn: filterFn,
address: address,
mtu: mtu,
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
recvChan: make(chan recvMessage, 1),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
mtu: mtu,
address: address,
activityRecorder: NewActivityRecorder(), activityRecorder: NewActivityRecorder(),
} }
@@ -82,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
return ib return ib
} }
func (s *ICEBind) MTU() uint16 {
return s.mtu
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false s.closed = false
s.closedChanMu.Lock() s.closedChanMu.Lock()
@@ -115,7 +111,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
} }
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind // GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
if s.udpMux == nil { if s.udpMux == nil {
@@ -138,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
delete(b.endpoints, fakeIP) delete(b.endpoints, fakeIP)
} }
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-b.closedChan:
return
case <-ctx.Done():
return
case b.recvChan <- recvMessage{ep, buf}:
}
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock() b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()] conn, ok := b.endpoints[ep.DstIP()]
@@ -158,8 +164,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = udpmux.NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn), UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
@@ -270,7 +276,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
select { select {
case <-c.closedChan: case <-c.closedChan:
return 0, net.ErrClosed return 0, net.ErrClosed
case msg, ok := <-c.RecvChan: case msg, ok := <-c.recvChan:
if !ok { if !ok {
return 0, net.ErrClosed return 0, net.ErrClosed
} }

View File

@@ -0,0 +1,6 @@
package bind
type recvMessage struct {
Endpoint *Endpoint
Buffer []byte
}

View File

@@ -0,0 +1,125 @@
package bind
import (
"context"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
)
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
// Do not limit to build only js, because we want to be able to run tests
type RelayBindJS struct {
*conn.StdNetBind
recvChan chan recvMessage
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
activityRecorder *ActivityRecorder
ctx context.Context
cancel context.CancelFunc
}
func NewRelayBindJS() *RelayBindJS {
return &RelayBindJS{
recvChan: make(chan recvMessage, 100),
endpoints: make(map[netip.Addr]net.Conn),
activityRecorder: NewActivityRecorder(),
}
}
// Open creates a receive function for handling relay packets in WASM.
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
log.Debugf("Open: creating receive function for port %d", uport)
s.ctx, s.cancel = context.WithCancel(context.Background())
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
select {
case <-s.ctx.Done():
return 0, net.ErrClosed
case msg, ok := <-s.recvChan:
if !ok {
return 0, net.ErrClosed
}
copy(bufs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = conn.Endpoint(msg.Endpoint)
return 1, nil
}
}
log.Debugf("Open: receive function created, returning port %d", uport)
return []conn.ReceiveFunc{receiveFn}, uport, nil
}
func (s *RelayBindJS) Close() error {
if s.cancel == nil {
return nil
}
log.Debugf("close RelayBindJS")
s.cancel()
return nil
}
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-s.ctx.Done():
return
case <-ctx.Done():
return
case s.recvChan <- recvMessage{ep, buf}:
}
}
// Send forwards packets through the relay connection for WASM.
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
if ep == nil {
return nil
}
fakeIP := ep.DstIP()
s.endpointsMu.Lock()
relayConn, ok := s.endpoints[fakeIP]
s.endpointsMu.Unlock()
if !ok {
return nil
}
for _, buf := range bufs {
if _, err := relayConn.Write(buf); err != nil {
return err
}
}
return nil
}
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock()
}
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
s.endpointsMu.Lock()
defer s.endpointsMu.Unlock()
delete(s.endpoints, fakeIP)
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, ErrUDPMUXNotSupported
}
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}

View File

@@ -1,7 +0,0 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
//go:build linux || windows || freebsd //go:build linux || windows || freebsd || js || wasip1
package configurer package configurer

View File

@@ -1,4 +1,4 @@
//go:build !windows //go:build !windows && !js
package configurer package configurer

View File

@@ -0,0 +1,23 @@
package configurer
import (
"net"
)
type noopListener struct{}
func (n *noopListener) Accept() (net.Conn, error) {
return nil, net.ErrClosed
}
func (n *noopListener) Close() error {
return nil
}
func (n *noopListener) Addr() net.Addr {
return nil
}
func openUAPI(deviceName string) (net.Listener, error) {
return &noopListener{}, nil
}

View File

@@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) {
if err != nil { if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
} }
// If sec is 0 (Unix epoch), return zero time instead
// This indicates no handshake has occurred
if sec == 0 {
return time.Time{}, nil
}
return time.Unix(sec, 0), nil return time.Unix(sec, 0), nil
} }
@@ -402,7 +409,7 @@ func toBytes(s string) (int64, error) {
} }
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() { if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
return nbnet.ControlPlaneMark return nbnet.ControlPlaneMark
} }
return 0 return 0

View File

@@ -7,14 +7,14 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create() (device.WGConfigurer, error) Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16 MTU() uint16

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -29,7 +30,7 @@ type WGTunDevice struct {
name string name string
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
} }
return t.configurer, nil return t.configurer, nil
} }
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -26,7 +27,7 @@ type TunDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -28,7 +29,7 @@ type TunDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -12,11 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type TunKernelDevice struct { type TunKernelDevice struct {
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
link *wgLink link *wgLink
udpMuxConn net.PacketConn udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
filterFn bind.FilterFn filterFn udpmux.FilterFn
} }
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil return configurer, nil
} }
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.udpMux != nil { if t.udpMux != nil {
return t.udpMux, nil return t.udpMux, nil
} }
@@ -101,19 +101,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
var udpConn net.PacketConn = rawSock bindParams := udpmux.UniversalUDPMuxParams{
if !nbnet.AdvancedRouting() { UDPConn: nbnet.WrapPacketConn(rawSock),
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: udpConn,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address, WGAddress: t.address,
MTU: t.mtu, MTU: t.mtu,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock t.udpMuxConn = rawSock
t.udpMux = mux t.udpMux = mux

View File

@@ -1,19 +1,28 @@
package device package device
import ( import (
"errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
}
type TunNetstackDevice struct { type TunNetstackDevice struct {
name string name string
address wgaddr.Address address wgaddr.Address
@@ -21,18 +30,18 @@ type TunNetstackDevice struct {
key string key string
mtu uint16 mtu uint16
listenAddress string listenAddress string
iceBind *bind.ICEBind bind Bind
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun nsTun *nbnetstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
net *netstack.Net net *netstack.Net
} }
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -40,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
key: key, key: key,
mtu: mtu, mtu: mtu,
listenAddress: listenAddress, listenAddress: listenAddress,
iceBind: iceBind, bind: bind,
} }
} }
@@ -65,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.device = device.NewDevice( t.device = device.NewDevice(
t.filteredDevice, t.filteredDevice,
t.iceBind, t.bind,
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
_ = tunIface.Close() _ = tunIface.Close()
@@ -80,7 +89,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }
@@ -90,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
udpMux, err := t.iceBind.GetICEMux() udpMux, err := t.bind.GetICEMux()
if err != nil { if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) {
return nil, err return nil, err
} }
if udpMux != nil {
t.udpMux = udpMux t.udpMux = udpMux
}
log.Debugf("netstack device is ready to use") log.Debugf("netstack device is ready to use")
return udpMux, nil return udpMux, nil
} }

View File

@@ -0,0 +1,27 @@
package device
import (
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestNewNetstackDevice(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
relayBind := bind.NewRelayBindJS()
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
cfgr, err := nsTun.Create()
if err != nil {
t.Fatalf("failed to create netstack device: %v", err)
}
if cfgr == nil {
t.Fatal("expected non-nil configurer")
}
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -25,7 +26,7 @@ type USPDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -29,7 +30,7 @@ type TunDevice struct {
device *device.Device device *device.Device
nativeTunDevice *tun.NativeTun nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -5,14 +5,14 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16 MTU() uint16

View File

@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
MTU uint16 MTU uint16
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn udpmux.FilterFn
DisableDNS bool DisableDNS bool
} }
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface // Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before) // The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()

View File

@@ -0,0 +1,6 @@
package iface
// Destroy is a no-op on WASM
func (w *WGIface) Destroy() error {
return nil
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil
} }
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: tun, tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -0,0 +1,41 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true, userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -0,0 +1,27 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
relayBind := bind.NewRelayBindJS()
wgIface := &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
}
return wgIface, nil
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build linux && !android
package iface package iface
@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil return wgIFace, nil
} }
@@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: tun, tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil

View File

@@ -1,3 +1,5 @@
//go:build !js
package netstack package netstack
import ( import (

View File

@@ -0,0 +1,12 @@
package netstack
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
// IsEnabled always returns true for js since it's the only mode available
func IsEnabled() bool {
return true
}
func ListenAddr() string {
return ""
}

View File

@@ -1,4 +1,4 @@
package bind package udpmux
/* /*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,11 +16,12 @@ import (
) )
type udpMuxedConnParams struct { type udpMuxedConnParams struct {
Mux *UDPMuxDefault Mux *SingleSocketUDPMux
AddrPool *sync.Pool AddrPool *sync.Pool
Key string Key string
LocalAddr net.Addr LocalAddr net.Addr
Logger logging.LeveledLogger Logger logging.LeveledLogger
CandidateID string
} }
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error {
return err return err
} }
func (c *udpMuxedConn) GetCandidateID() string {
return c.params.CandidateID
}
func (c *udpMuxedConn) isClosed() bool { func (c *udpMuxedConn) isClosed() bool {
select { select {
case <-c.closedChan: case <-c.closedChan:

View File

@@ -0,0 +1,64 @@
// Package udpmux provides a custom implementation of a UDP multiplexer
// that allows multiple logical ICE connections to share a single underlying
// UDP socket. This is based on Pion's ICE library, with modifications for
// NetBird's requirements.
//
// # Background
//
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
// Establishment) is responsible for discovering candidate network paths
// and maintaining connectivity between peers. Each ICE connection
// normally requires a dedicated UDP socket. However, using one socket
// per candidate can be inefficient and difficult to manage.
//
// This package introduces SingleSocketUDPMux, which allows multiple ICE
// candidate connections (muxed connections) to share a single UDP socket.
// It handles demultiplexing of packets based on ICE ufrag values, STUN
// attributes, and candidate IDs.
//
// # Usage
//
// The typical flow is:
//
// 1. Create a UDP socket (net.PacketConn).
// 2. Construct Params with the socket and optional logger/net stack.
// 3. Call NewSingleSocketUDPMux(params).
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
// to obtain a logical PacketConn.
// 5. Use the returned PacketConn just like a normal UDP connection.
//
// # STUN Message Routing Logic
//
// When a STUN packet arrives, the mux decides which connection should
// receive it using this routing logic:
//
// Primary Routing: Candidate Pair ID
// - Extract the candidate pair ID from the STUN message using
// ice.CandidatePairIDFromSTUN(msg)
// - The target candidate is the locally generated candidate that
// corresponds to the connection that should handle this STUN message
// - If found, use the target candidate ID to lookup the specific
// connection in candidateConnMap
// - Route the message directly to that connection
//
// Fallback Routing: Broadcasting
// When candidate pair ID is not available or lookup fails:
// - Collect connections from addressMap based on source address
// - Find connection using username attribute (ufrag) from STUN message
// - Remove duplicate connections from the list
// - Send the STUN message to all collected connections
//
// # Peer Reflexive Candidate Discovery
//
// When a remote peer sends a STUN message from an unknown source address
// (from a candidate that has not been exchanged via signal), the ICE
// library will:
// - Generate a new peer reflexive candidate for this source address
// - Extract or assign a candidate ID based on the STUN message attributes
// - Create a mapping between the new peer reflexive candidate ID and
// the appropriate local connection
//
// This discovery mechanism ensures that STUN messages from newly discovered
// peer reflexive candidates can be properly routed to the correct local
// connection without requiring fallback broadcasting.
package udpmux

View File

@@ -1,4 +1,4 @@
package bind package udpmux
import ( import (
"fmt" "fmt"
@@ -8,9 +8,9 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -22,9 +22,9 @@ import (
const receiveMTU = 8192 const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface // SingleSocketUDPMux is an implementation of the interface
type UDPMuxDefault struct { type SingleSocketUDPMux struct {
params UDPMuxParams params Params
closedChan chan struct{} closedChan chan struct{}
closeOnce sync.Once closeOnce sync.Once
@@ -32,6 +32,9 @@ type UDPMuxDefault struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn connsIPv4, connsIPv6 map[string]*udpMuxedConn
// candidateConnMap maps local candidate IDs to their corresponding connection.
candidateConnMap map[string]*udpMuxedConn
addressMapMu sync.RWMutex addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn addressMap map[string][]*udpMuxedConn
@@ -46,8 +49,8 @@ type UDPMuxDefault struct {
const maxAddrSize = 512 const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux. // Params are parameters for UDPMux.
type UDPMuxParams struct { type Params struct {
Logger logging.LeveledLogger Logger logging.LeveledLogger
UDPConn net.PacketConn UDPConn net.PacketConn
@@ -147,17 +150,18 @@ func isZeros(ip net.IP) bool {
return true return true
} }
// NewUDPMuxDefault creates an implementation of UDPMux // NewSingleSocketUDPMux creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
if params.Logger == nil { if params.Logger == nil {
params.Logger = getLogger() params.Logger = getLogger()
} }
mux := &UDPMuxDefault{ mux := &SingleSocketUDPMux{
addressMap: map[string][]*udpMuxedConn{}, addressMap: map[string][]*udpMuxedConn{},
params: params, params: params,
connsIPv4: make(map[string]*udpMuxedConn), connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn), connsIPv6: make(map[string]*udpMuxedConn),
candidateConnMap: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1), closedChan: make(chan struct{}, 1),
pool: &sync.Pool{ pool: &sync.Pool{
New: func() interface{} { New: func() interface{} {
@@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return mux return mux
} }
func (m *UDPMuxDefault) updateLocalAddresses() { func (m *SingleSocketUDPMux) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() { } else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but // For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection // it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux. // with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType var networks []ice.NetworkType
switch { switch {
@@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
m.mu.Unlock() m.mu.Unlock()
} }
// LocalAddr returns the listening address of this UDPMuxDefault // LocalAddr returns the listening address of this SingleSocketUDPMux
func (m *UDPMuxDefault) LocalAddr() net.Addr { func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr() return m.params.UDPConn.LocalAddr()
} }
// GetListenAddresses returns the list of addresses that this mux is listening on // GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
m.updateLocalAddresses() m.updateLocalAddresses()
m.mu.Lock() m.mu.Lock()
@@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address // GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
// don't check addr for mux using unspecified address // don't check addr for mux using unspecified address
m.mu.Lock() m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified) lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return conn, nil return conn, nil
} }
c := m.createMuxedConn(ufrag) c := m.createMuxedConn(ufrag, candidateID)
go func() { go func() {
<-c.CloseChannel() <-c.CloseChannel()
m.RemoveConnByUfrag(ufrag) m.RemoveConnByUfrag(ufrag)
}() }()
m.candidateConnMap[candidateID] = c
if isIPv6 { if isIPv6 {
m.connsIPv6[ufrag] = c m.connsIPv6[ufrag] = c
} else { } else {
@@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
} }
// RemoveConnByUfrag stops and removes the muxed packet connection // RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2) removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock // Keep lock section small to avoid deadlock with conn lock
@@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok { if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag) delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c) removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
} }
if c, ok := m.connsIPv6[ufrag]; ok { if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag) delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c) removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
} }
m.mu.Unlock() m.mu.Unlock()
@@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
} }
// IsClosed returns true if the mux had been closed // IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool { func (m *SingleSocketUDPMux) IsClosed() bool {
select { select {
case <-m.closedChan: case <-m.closedChan:
return true return true
@@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
} }
// Close the mux, no further connections could be created // Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error { func (m *SingleSocketUDPMux) Close() error {
var err error var err error
m.closeOnce.Do(func() { m.closeOnce.Do(func() {
m.mu.Lock() m.mu.Lock()
@@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error {
return err return err
} }
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr) return m.params.UDPConn.WriteTo(buf, rAddr)
} }
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() { if m.IsClosed() {
return return
} }
@@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
} }
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{ c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m, Mux: m,
Key: key, Key: key,
AddrPool: m.pool, AddrPool: m.pool,
LocalAddr: m.LocalAddr(), LocalAddr: m.LocalAddr(),
Logger: m.params.Logger, Logger: m.params.Logger,
CandidateID: candidateID,
}) })
return c return c
} }
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr) remoteAddr, ok := addr.(*net.UDPAddr)
if !ok { if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr") return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
} }
// If we have already seen this address dispatch to the appropriate destination // Try to route to specific candidate connection first
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one if conn := m.findCandidateConnection(msg); conn != nil {
// muxed connection - one for the SRFLX candidate and the other one for the HOST one. return conn.writePacket(msg.Raw, remoteAddr)
// We will then forward STUN packets to each of these connections. }
m.addressMapMu.RLock()
// Fallback: route to all possible connections
return m.forwardToAllConnections(msg, addr, remoteAddr)
}
// findCandidateConnection attempts to find the specific connection for a STUN message
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
if err != nil {
return nil
} else if !ok {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
if !exists {
return nil
}
return conn
}
// forwardToAllConnections forwards STUN message to all relevant connections
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn var destinationConnList []*udpMuxedConn
// Add connections from address map
m.addressMapMu.RLock()
if storedConns, ok := m.addressMap[addr.String()]; ok { if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...) destinationConnList = append(destinationConnList, storedConns...)
} }
m.addressMapMu.RUnlock() m.addressMapMu.RUnlock()
var isIPv6 bool if conn, ok := m.findConnectionByUsername(msg, addr); ok {
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { // If we have already seen this address dispatch to the appropriate destination
isIPv6 = true // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
if !m.connectionExists(conn, destinationConnList) {
destinationConnList = append(destinationConnList, conn)
}
} }
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. // Forward to all found connections
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList { for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err) log.Errorf("could not write packet: %v", err)
} }
} }
return nil return nil
} }
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { // findConnectionByUsername finds connection using username attribute from STUN message
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
return nil, false
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := isIPv6Address(addr)
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(ufrag, isIPv6)
}
// connectionExists checks if a connection already exists in the list
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
for _, conn := range conns {
if conn.params.Key == target.params.Key {
return true
}
}
return false
}
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 { if isIPv6 {
val, ok = m.connsIPv6[ufrag] val, ok = m.connsIPv6[ufrag]
} else { } else {
@@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
return return
} }
func isIPv6Address(addr net.Addr) bool {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
return udpAddr.IP.To4() == nil
}
return false
}
type bufferHolder struct { type bufferHolder struct {
buf []byte buf []byte
} }

View File

@@ -1,12 +1,12 @@
//go:build !ios //go:build !ios
package bind package udpmux
import ( import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr) conn.RemoveAddress(addr)

View File

@@ -0,0 +1,7 @@
//go:build ios
package udpmux
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
package bind package udpmux
/* /*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -15,7 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing. // It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct { type UniversalUDPMuxDefault struct {
*UDPMuxDefault *SingleSocketUDPMux
params UniversalUDPMuxParams params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress, address: params.WGAddress,
} }
udpMuxParams := UDPMuxParams{ udpMuxParams := Params{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
Net: m.params.Net, Net: m.params.Net,
} }
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
return m return m
} }
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server. // and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
} }
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
} }
return nil return nil
} }
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr) return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
} }
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -16,28 +16,38 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
type ProxyBind struct { type Bind interface {
Bind *bind.ICEBind SetEndpoint(addr netip.Addr, conn net.Conn)
RemoveEndpoint(addr netip.Addr)
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
}
fakeNetIP *netip.AddrPort type ProxyBind struct {
wgBindEndpoint *bind.Endpoint bind Bind
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
wgRelayedEndpoint *bind.Endpoint
wgCurrentUsed *bind.Endpoint
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex
paused bool paused bool
pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
mtu uint16
} }
func NewProxyBind(bind *bind.ICEBind) *ProxyBind { func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
p := &ProxyBind{ p := &ProxyBind{
Bind: bind, bind: bind,
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
pausedCond: sync.NewCond(&sync.Mutex{}),
mtu: mtu + bufsize.WGBufferOverhead,
} }
return p return p
@@ -46,25 +56,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration. // WireGuard configuration.
//
// Parameters:
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
// - remoteConn: The established TURN connection to the remote peer
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr) fakeNetIP, err := fakeAddress(nbAddr)
if err != nil { if err != nil {
return err return err
} }
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
return nil return nil
} }
func (p *ProxyBind) EndpointAddr() *net.UDPAddr { func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return &net.UDPAddr{ return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
} }
func (p *ProxyBind) SetDisconnectListener(disconnected func()) { func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
@@ -76,17 +86,21 @@ func (p *ProxyBind) Work() {
return return
} }
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgCurrentUsed = p.wgRelayedEndpoint
// Start the proxy only once // Start the proxy only once
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
func (p *ProxyBind) Pause() { func (p *ProxyBind) Pause() {
@@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
} }
func (p *ProxyBind) CloseConn() error { func (p *ProxyBind) CloseConn() error {
@@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
} }
func (p *ProxyBind) close() error { func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil
}
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -120,7 +154,12 @@ func (p *ProxyBind) close() error {
p.cancel() p.cancel()
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr return rErr
@@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}() }()
for { for {
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead) buf := make([]byte, p.mtu)
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -147,18 +186,13 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
msg := bind.RecvMessage{ p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])
Endpoint: p.wgBindEndpoint, p.pausedCond.L.Unlock()
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
} }
} }

View File

@@ -6,9 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"syscall"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@@ -18,15 +16,20 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const ( const (
loopbackAddr = "127.0.0.1" loopbackAddr = "127.0.0.1"
) )
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
@@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
return err return err
} }
p.rawConn, err = p.prepareSenderRawSocket() p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil { if err != nil {
return err return err
} }
@@ -214,57 +217,17 @@ generatePort:
return p.lastUsedPort, nil return p.lastUsedPort, nil
} }
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data) payload := gopacket.Payload(data)
ipH := &layers.IPv4{ ipH := &layers.IPv4{
DstIP: localhost, DstIP: localHostNetIP,
SrcIP: localhost, SrcIP: endpointAddr.IP,
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
udpH := &layers.UDP{ udpH := &layers.UDP{
SrcPort: layers.UDPPort(port), SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort), DstPort: layers.UDPPort(p.localWGListenPort),
} }
@@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
if err != nil { if err != nil {
return fmt.Errorf("serialize layers: %w", err) return fmt.Errorf("serialize layers: %w", err)
} }
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err) return fmt.Errorf("write to raw conn: %w", err)
} }
return nil return nil

View File

@@ -18,41 +18,42 @@ import (
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct { type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy wgeBPFProxy *WGEBPFProxy
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool paused bool
pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
} }
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{ return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy, wgeBPFProxy: proxy,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return fmt.Errorf("add turn conn: %w", err) return fmt.Errorf("add turn conn: %w", err)
} }
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr p.wgRelayedEndpointAddr = addr
return err return err
} }
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr return p.wgRelayedEndpointAddr
} }
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
@@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
func (p *ProxyWrapper) Pause() { func (p *ProxyWrapper) Pause() {
@@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() {
} }
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
// CloseConn close the remoteConn and automatically remove the conn instance from the map // CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error { func (p *ProxyWrapper) CloseConn() error {
if e.cancel == nil { if p.cancel == nil {
return fmt.Errorf("proxy not started") return fmt.Errorf("proxy not started")
} }
e.cancel() p.cancel()
e.closeListener.SetCloseListener(nil) p.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { p.pausedCond.L.Lock()
return fmt.Errorf("close remote conn: %w", err) p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
} }
return nil return nil
} }
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead) buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
for { for {
n, err := p.readFromRemote(ctx, buf) n, err := p.readFromRemote(ctx, buf)
if err != nil { if err != nil {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
} }
p.closeListener.Notify() p.closeListener.Notify()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
} }
return 0, err return 0, err
} }

View File

@@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy {
} }
return ebpf.NewProxyWrapper(w.ebpfProxy) return ebpf.NewProxyWrapper(w.ebpfProxy)
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -1,31 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
mtu uint16
}
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
mtu: mtu,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -3,24 +3,25 @@ package wgproxy
import ( import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
) )
type USPFactory struct { type USPFactory struct {
bind *bind.ICEBind bind proxyBind.Bind
mtu uint16
} }
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy") log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{ f := &USPFactory{
bind: iceBind, bind: bind,
mtu: mtu,
} }
return f return f
} }
func (w *USPFactory) GetProxy() Proxy { func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind) return proxyBind.NewProxyBind(w.bind, w.mtu)
} }
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {

View File

@@ -11,6 +11,11 @@ type Proxy interface {
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
//RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
//and rewrite the src address to the endpoint address.
//With this logic can avoid the package loss from relayed connections.
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func()) SetDisconnectListener(disconnected func())
} }

View File

@@ -3,54 +3,82 @@
package wgproxy package wgproxy
import ( import (
"context" "fmt"
"os" "net"
"testing"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
) )
func TestProxyCloseByRemoteConnEBPF(t *testing.T) { func seedProxies() ([]proxyInstance, error) {
if os.Getenv("GITHUB_ACTIONS") != "true" { pl := make([]proxyInstance, 0)
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err) return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
} }
defer func() { pEbpf := proxyInstance{
if err := ebpfProxy.Free(); err != nil { name: "ebpf kernel proxy",
t.Errorf("failed to free ebpf proxy: %s", err) proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
} }
}() pl = append(pl, pEbpf)
tests := []struct { pUDP := proxyInstance{
name string name: "udp kernel proxy",
proxy Proxy proxy: udp.NewWGUDPProxy(51832, 1280),
}{ wgPort: 51832,
{ closeFn: func() error { return nil },
name: "ebpf proxy", }
proxy: &ebpf.ProxyWrapper{ pl = append(pl, pUDP)
WgeBPFProxy: ebpfProxy, return pl, nil
},
},
} }
for _, tt := range tests { func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
t.Run(tt.name, func(t *testing.T) { pl := make([]proxyInstance, 0)
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
}
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil { if err != nil {
t.Errorf("error: %v", err) return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
} }
_ = relayedConn.Close() pBind := proxyInstance{
if err := tt.proxy.CloseConn(); err != nil { name: "bind proxy",
t.Errorf("error: %v", err) proxy: bindproxy.NewProxyBind(iceBind, 0),
} endpointAddr: endpointAddress,
}) closeFn: func() error { return nil },
} }
pl = append(pl, pBind)
return pl, nil
} }

View File

@@ -0,0 +1,39 @@
//go:build !linux
package wgproxy
import (
"net"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
func seedProxies() ([]proxyInstance, error) {
// todo extend with Bind proxy
pl := make([]proxyInstance, 0)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -1,5 +1,3 @@
//go:build linux
package wgproxy package wgproxy
import ( import (
@@ -7,12 +5,9 @@ import (
"io" "io"
"net" "net"
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -22,6 +17,14 @@ func TestMain(m *testing.M) {
os.Exit(code) os.Exit(code)
} }
type proxyInstance struct {
name string
proxy Proxy
wgPort int
endpointAddr *net.UDPAddr
closeFn func() error
}
type mocConn struct { type mocConn struct {
closeChan chan struct{} closeChan chan struct{}
closed bool closed bool
@@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
func TestProxyCloseByRemoteConn(t *testing.T) { func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tests := []struct { tests, err := seedProxyForProxyCloseByRemoteConn()
name string if err != nil {
proxy Proxy t.Fatalf("error: %v", err)
}{
{
name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
},
} }
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() { defer func() {
if err := ebpfProxy.Free(); err != nil { _ = relayedConn.Close()
t.Errorf("failed to free ebpf proxy: %s", err)
}
}() }()
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
relayedConn := newMockConn() relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
if err != nil { if err != nil {
t.Errorf("error: %v", err) t.Errorf("error: %v", err)
} }
@@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
}) })
} }
} }
// TestProxyRedirect todo extend the proxies with Bind proxy
func TestProxyRedirect(t *testing.T) {
tests, err := seedProxies()
if err != nil {
t.Fatalf("error: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
if err := tt.closeFn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
t.Helper()
msgHelloFromRelay := []byte("hello from relay")
msgRedirected := [][]byte{
[]byte("hello 1. to p2p"),
[]byte("hello 2. to p2p"),
[]byte("hello 3. to p2p"),
}
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: wgPort})
if err != nil {
t.Fatalf("failed to listen on udp port: %s", err)
}
relayedServer, _ := net.ListenUDP("udp",
&net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
},
)
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = dummyWgListener.Close()
_ = relayedConn.Close()
_ = relayedServer.Close()
}()
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
t.Errorf("error: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
}()
proxy.Work()
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
}
n, err := dummyWgListener.Read(make([]byte, 1024))
if err != nil {
t.Errorf("error: %v", err)
}
if n != len(msgHelloFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
}
p2pEndpointAddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 56),
Port: 1234,
}
proxy.RedirectAs(p2pEndpointAddr)
for _, msg := range msgRedirected {
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
t.Errorf("error: %v", err)
}
}
for i := 0; i < len(msgRedirected); i++ {
buf := make([]byte, 1024)
n, rAddr, err := dummyWgListener.ReadFrom(buf)
if err != nil {
t.Errorf("error: %v", err)
}
if rAddr.String() != p2pEndpointAddr.String() {
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
}
if string(buf[:n]) != string(msgRedirected[i]) {
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
}
}
}

View File

@@ -0,0 +1,50 @@
//go:build linux && !android
package rawsocket
import (
"fmt"
"net"
"os"
"syscall"
nbnet "github.com/netbirdio/netbird/client/net"
)
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}

View File

@@ -1,3 +1,5 @@
//go:build linux && !android
package udp package udp
import ( import (
@@ -23,13 +25,15 @@ type WGUDPProxy struct {
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
srcFakerConn *SrcFaker
sendPkg func(data []byte) (int, error)
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex
paused bool paused bool
pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
@@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
mtu: mtu, mtu: mtu,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
return p return p
@@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn p.localConn = localConn
p.sendPkg = p.localConn.Write
p.remoteConn = remoteConn p.remoteConn = remoteConn
return err return err
@@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock() p.sendPkg = p.localConn.Write
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToRemote(p.ctx) go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
// Pause pauses the proxy from receiving data from the remote peer // Pause pauses the proxy from receiving data from the remote peer
@@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
// RedirectAs start to use the fake sourced raw socket as package sender
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
defer func() {
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}()
p.paused = false
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create src faker conn: %s", err)
// fallback to continue without redirecting
p.paused = true
return
}
p.srcFakerConn = srcFakerConn
p.sendPkg = p.srcFakerConn.SendPkg
} }
// CloseConn close the localConn // CloseConn close the localConn
@@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error {
} }
func (p *WGUDPProxy) close() error { func (p *WGUDPProxy) close() error {
var result *multierror.Error
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error {
p.cancel() p.cancel()
var result *multierror.Error p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
} }
@@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error {
if err := p.localConn.Close(); err != nil { if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
} }
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
}
}
return cerrors.FormatErrorOrNil(result) return cerrors.FormatErrorOrNil(result)
} }
@@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
_, err = p.sendPkg(buf[:n])
_, err = p.localConn.Write(buf[:n]) p.pausedCond.L.Unlock()
p.pausedMu.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {

View File

@@ -0,0 +1,101 @@
//go:build linux && !android
package udp
import (
"fmt"
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
)
var (
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
}
func (f *SrcFaker) Close() error {
return f.rawSocket.Close()
}
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
defer func() {
if err := f.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
return n, nil
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return ipH, udpH, nil
}

View File

@@ -34,7 +34,7 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -275,7 +275,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engineMutex.Unlock() c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil { if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
@@ -284,10 +284,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil { if runningChan != nil {
select { close(runningChan)
case runningChan <- struct{}{}: runningChan = nil
default:
}
} }
<-engineCtx.Done() <-engineCtx.Done()

View File

@@ -0,0 +1,201 @@
package config
import (
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
ErrEmptyURL = errors.New("empty URL")
ErrEmptyHost = errors.New("empty host")
ErrIPNotAllowed = errors.New("IP address not allowed")
)
// ServerDomains represents the management server domains extracted from NetBird configuration
type ServerDomains struct {
Signal domain.Domain
Relay []domain.Domain
Flow domain.Domain
Stuns []domain.Domain
Turns []domain.Domain
}
// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration
func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
if config == nil {
return ServerDomains{}
}
domains := ServerDomains{}
domains.Signal = extractSignalDomain(config)
domains.Relay = extractRelayDomains(config)
domains.Flow = extractFlowDomain(config)
domains.Stuns = extractStunDomains(config)
domains.Turns = extractTurnDomains(config)
return domains
}
// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func ExtractValidDomain(rawURL string) (domain.Domain, error) {
if rawURL == "" {
return "", ErrEmptyURL
}
parsedURL, err := url.Parse(rawURL)
if err == nil {
if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != "" {
return domain, err
}
}
return extractFromRawString(rawURL)
}
// extractFromParsedURL handles domain extraction from successfully parsed URLs
func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) {
if parsedURL.Hostname() != "" {
return extractDomainFromHost(parsedURL.Hostname())
}
if parsedURL.Opaque == "" || parsedURL.Scheme == "" {
return "", nil
}
// Handle URLs with opaque content (e.g., stun:host:port)
if strings.Contains(parsedURL.Scheme, ".") {
// This is likely "domain.com:port" being parsed as scheme:opaque
reconstructed := parsedURL.Scheme + ":" + parsedURL.Opaque
if host, _, err := net.SplitHostPort(reconstructed); err == nil {
return extractDomainFromHost(host)
}
return extractDomainFromHost(parsedURL.Scheme)
}
// Valid scheme with opaque content (e.g., stun:host:port)
host := parsedURL.Opaque
if queryIndex := strings.Index(host, "?"); queryIndex > 0 {
host = host[:queryIndex]
}
if hostOnly, _, err := net.SplitHostPort(host); err == nil {
return extractDomainFromHost(hostOnly)
}
return extractDomainFromHost(host)
}
// extractFromRawString handles domain extraction when URL parsing fails or returns no results
func extractFromRawString(rawURL string) (domain.Domain, error) {
if host, _, err := net.SplitHostPort(rawURL); err == nil {
return extractDomainFromHost(host)
}
return extractDomainFromHost(rawURL)
}
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
func extractDomainFromHost(host string) (domain.Domain, error) {
if host == "" {
return "", ErrEmptyHost
}
if _, err := netip.ParseAddr(host); err == nil {
return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host)
}
d, err := domain.FromString(host)
if err != nil {
return "", fmt.Errorf("invalid domain: %v", err)
}
return d, nil
}
// extractSingleDomain extracts a single domain from a URL with error logging
func extractSingleDomain(url, serviceType string) domain.Domain {
if url == "" {
return ""
}
d, err := ExtractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
return ""
}
return d
}
// extractMultipleDomains extracts multiple domains from URLs with error logging
func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
var domains []domain.Domain
for _, url := range urls {
if url == "" {
continue
}
d, err := ExtractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
continue
}
domains = append(domains, d)
}
return domains
}
// extractSignalDomain extracts the signal domain from NetBird configuration.
func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Signal != nil {
return extractSingleDomain(config.Signal.Uri, "signal")
}
return ""
}
// extractRelayDomains extracts relay server domains from NetBird configuration.
func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Relay != nil {
return extractMultipleDomains(config.Relay.Urls, "relay")
}
return nil
}
// extractFlowDomain extracts the traffic flow domain from NetBird configuration.
func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Flow != nil {
return extractSingleDomain(config.Flow.Url, "flow")
}
return ""
}
// extractStunDomains extracts STUN server domains from NetBird configuration.
func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" {
urls = append(urls, stun.Uri)
}
}
return extractMultipleDomains(urls, "STUN")
}
// extractTurnDomains extracts TURN server domains from NetBird configuration.
func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
urls = append(urls, turn.HostConfig.Uri)
}
}
return extractMultipleDomains(urls, "TURN")
}

View File

@@ -0,0 +1,213 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExtractValidDomain(t *testing.T) {
tests := []struct {
name string
url string
expected string
expectError bool
}{
{
name: "HTTPS URL with port",
url: "https://api.netbird.io:443",
expected: "api.netbird.io",
},
{
name: "HTTP URL without port",
url: "http://signal.example.com",
expected: "signal.example.com",
},
{
name: "Host with port (no scheme)",
url: "signal.netbird.io:443",
expected: "signal.netbird.io",
},
{
name: "STUN URL",
url: "stun:stun.netbird.io:443",
expected: "stun.netbird.io",
},
{
name: "STUN URL with different port",
url: "stun:stun.netbird.io:5555",
expected: "stun.netbird.io",
},
{
name: "TURNS URL with query params",
url: "turns:turn.netbird.io:443?transport=tcp",
expected: "turn.netbird.io",
},
{
name: "TURN URL",
url: "turn:turn.example.com:3478",
expected: "turn.example.com",
},
{
name: "REL URL",
url: "rel://relay.example.com:443",
expected: "relay.example.com",
},
{
name: "RELS URL",
url: "rels://relay.netbird.io:443",
expected: "relay.netbird.io",
},
{
name: "Raw hostname",
url: "example.org",
expected: "example.org",
},
{
name: "IP address should be rejected",
url: "192.168.1.1",
expectError: true,
},
{
name: "IP address with port should be rejected",
url: "192.168.1.1:443",
expectError: true,
},
{
name: "IPv6 address should be rejected",
url: "2001:db8::1",
expectError: true,
},
{
name: "HTTP URL with IPv4 should be rejected",
url: "http://192.168.1.1:8080",
expectError: true,
},
{
name: "HTTPS URL with IPv4 should be rejected",
url: "https://10.0.0.1:443",
expectError: true,
},
{
name: "STUN URL with IPv4 should be rejected",
url: "stun:192.168.1.1:3478",
expectError: true,
},
{
name: "TURN URL with IPv4 should be rejected",
url: "turn:10.0.0.1:3478",
expectError: true,
},
{
name: "TURNS URL with IPv4 should be rejected",
url: "turns:172.16.0.1:5349",
expectError: true,
},
{
name: "HTTP URL with IPv6 should be rejected",
url: "http://[2001:db8::1]:8080",
expectError: true,
},
{
name: "HTTPS URL with IPv6 should be rejected",
url: "https://[::1]:443",
expectError: true,
},
{
name: "STUN URL with IPv6 should be rejected",
url: "stun:[2001:db8::1]:3478",
expectError: true,
},
{
name: "IPv6 with port should be rejected",
url: "[2001:db8::1]:443",
expectError: true,
},
{
name: "Localhost IPv4 should be rejected",
url: "127.0.0.1:8080",
expectError: true,
},
{
name: "Localhost IPv6 should be rejected",
url: "[::1]:443",
expectError: true,
},
{
name: "REL URL with IPv4 should be rejected",
url: "rel://192.168.1.1:443",
expectError: true,
},
{
name: "RELS URL with IPv4 should be rejected",
url: "rels://10.0.0.1:443",
expectError: true,
},
{
name: "Empty URL",
url: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := ExtractValidDomain(tt.url)
if tt.expectError {
assert.Error(t, err, "Expected error for URL: %s", tt.url)
} else {
assert.NoError(t, err, "Unexpected error for URL: %s", tt.url)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url)
}
})
}
}
func TestExtractDomainFromHost(t *testing.T) {
tests := []struct {
name string
host string
expected string
expectError bool
}{
{
name: "Valid domain",
host: "example.com",
expected: "example.com",
},
{
name: "Subdomain",
host: "api.example.com",
expected: "api.example.com",
},
{
name: "IPv4 address",
host: "192.168.1.1",
expectError: true,
},
{
name: "IPv6 address",
host: "2001:db8::1",
expectError: true,
},
{
name: "Empty host",
host: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extractDomainFromHost(tt.host)
if tt.expectError {
assert.Error(t, err, "Expected error for host: %s", tt.host)
} else {
assert.NoError(t, err, "Unexpected error for host: %s", tt.host)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host)
}
})
}
}

View File

@@ -11,6 +11,7 @@ import (
) )
const ( const (
PriorityMgmtCache = 150
PriorityLocal = 100 PriorityLocal = 100
PriorityDNSRoute = 75 PriorityDNSRoute = 75
PriorityUpstream = 50 PriorityUpstream = 50
@@ -182,7 +183,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants to continue, try next handler // If handler wants to continue, try next handler
if chainWriter.shouldContinue { if chainWriter.shouldContinue {
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname) log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue continue
} }
return return

View File

@@ -240,15 +240,19 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains { for i, domain := range domains {
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
if r.gpo { gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
}
singleDomain := []string{domain} singleDomain := []string{domain}
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil { if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err) return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
}
if r.gpo {
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
}
} }
log.Debugf("added NRPT entry for domain: %s", domain) log.Debugf("added NRPT entry for domain: %s", domain)
@@ -401,6 +405,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
} }
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
} }
@@ -412,6 +417,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err)) merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
} }
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err)) merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
} }

View File

@@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool {
// String returns a string representation of the local resolver // String returns a string representation of the local resolver
func (d *Resolver) String() string { func (d *Resolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.records)) return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
} }
func (d *Resolver) Stop() {} func (d *Resolver) Stop() {}

View File

@@ -0,0 +1,360 @@
package mgmt
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
const dnsTimeout = 5 * time.Second
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
}
}
// String returns a string representation of the resolver.
func (m *Resolver) String() string {
return "MgmtCacheResolver"
}
// ServeDNS implements dns.Handler interface.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
m.continueToNext(w, r)
return
}
m.mutex.RLock()
records, found := m.records[question]
m.mutex.RUnlock()
if !found {
m.continueToNext(w, r)
return
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write response: %v", err)
}
}
// MatchSubdomains returns false since this resolver only handles exact domain matches
// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains.
func (m *Resolver) MatchSubdomains() bool {
return false
}
// continueToNext signals the handler chain to continue to the next handler.
func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write continue signal: %v", err)
}
}
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
if err != nil {
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
}
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
}
m.mutex.Lock()
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL == nil {
return nil
}
d, err := dnsconfig.ExtractValidDomain(mgmtURL.String())
if err != nil {
return fmt.Errorf("extract domain from URL: %w", err)
}
m.mutex.Lock()
m.mgmtDomain = &d
m.mutex.Unlock()
if err := m.AddDomain(ctx, d); err != nil {
return fmt.Errorf("add domain: %w", err)
}
return nil
}
// RemoveDomain removes a domain from the cache.
func (m *Resolver) RemoveDomain(d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.Lock()
defer m.mutex.Unlock()
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
}
// GetCachedDomains returns a list of all cached domains.
func (m *Resolver) GetCachedDomains() domain.List {
m.mutex.RLock()
defer m.mutex.RUnlock()
domainSet := make(map[domain.Domain]struct{})
for question := range m.records {
domainName := strings.TrimSuffix(question.Name, ".")
domainSet[domain.Domain(domainName)] = struct{}{}
}
domains := make(domain.List, 0, len(domainSet))
for d := range domainSet {
domains = append(domains, d)
}
return domains
}
// UpdateFromServerDomains updates the cache with server domains from network configuration.
// It merges new domains with existing ones, replacing entire domain types when updated.
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains domain.List
if len(newDomains) > 0 {
m.mutex.Lock()
if m.serverDomains == nil {
m.serverDomains = &dnsconfig.ServerDomains{}
}
updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains)
m.serverDomains = &updatedServerDomains
m.mutex.Unlock()
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
}
m.addNewDomains(ctx, newDomains)
return removedDomains, nil
}
// removeStaleDomains removes cached domains not present in the target domain list.
// Management domains are preserved and never removed during server domain updates.
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
var removedDomains domain.List
for _, currentDomain := range currentDomains {
if m.isDomainInList(currentDomain, newDomains) {
continue
}
if m.isManagementDomain(currentDomain) {
continue
}
removedDomains = append(removedDomains, currentDomain)
if err := m.RemoveDomain(currentDomain); err != nil {
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
}
}
return removedDomains
}
// mergeServerDomains merges new server domains with existing ones.
// When a domain type is provided in the new domains, it completely replaces that type.
func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains {
merged := existing
if incoming.Signal != "" {
merged.Signal = incoming.Signal
}
if len(incoming.Relay) > 0 {
merged.Relay = incoming.Relay
}
if incoming.Flow != "" {
merged.Flow = incoming.Flow
}
if len(incoming.Stuns) > 0 {
merged.Stuns = incoming.Stuns
}
if len(incoming.Turns) > 0 {
merged.Turns = incoming.Turns
}
return merged
}
// isDomainInList checks if domain exists in the list
func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool {
for _, d := range list {
if domain.SafeString() == d.SafeString() {
return true
}
}
return false
}
// isManagementDomain checks if domain is the protected management domain
func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
var domains domain.List
if serverDomains.Signal != "" {
domains = append(domains, serverDomains.Signal)
}
for _, relay := range serverDomains.Relay {
if relay != "" {
domains = append(domains, relay)
}
}
if serverDomains.Flow != "" {
domains = append(domains, serverDomains.Flow)
}
for _, stun := range serverDomains.Stuns {
if stun != "" {
domains = append(domains, stun)
}
}
for _, turn := range serverDomains.Turns {
if turn != "" {
domains = append(domains, turn)
}
}
return domains
}

View File

@@ -0,0 +1,416 @@
package mgmt
import (
"context"
"fmt"
"net/url"
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
func TestResolver_NewResolver(t *testing.T) {
resolver := NewResolver()
assert.NotNil(t, resolver)
assert.NotNil(t, resolver.records)
assert.False(t, resolver.MatchSubdomains())
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct {
name string
urlStr string
expectedDom string
expectError bool
}{
{
name: "HTTPS URL with port",
urlStr: "https://api.netbird.io:443",
expectedDom: "api.netbird.io",
expectError: false,
},
{
name: "HTTP URL without port",
urlStr: "http://signal.example.com",
expectedDom: "signal.example.com",
expectError: false,
},
{
name: "URL with path",
urlStr: "https://relay.netbird.io/status",
expectedDom: "relay.netbird.io",
expectError: false,
},
{
name: "Invalid URL",
urlStr: "not-a-valid-url",
expectedDom: "not-a-valid-url",
expectError: false,
},
{
name: "Empty URL",
urlStr: "",
expectedDom: "",
expectError: true,
},
{
name: "STUN URL",
urlStr: "stun:stun.example.com:3478",
expectedDom: "stun.example.com",
expectError: false,
},
{
name: "TURN URL",
urlStr: "turn:turn.example.com:3478",
expectedDom: "turn.example.com",
expectError: false,
},
{
name: "REL URL",
urlStr: "rel://relay.example.com:443",
expectedDom: "relay.example.com",
expectError: false,
},
{
name: "RELS URL",
urlStr: "rels://relay.example.com:443",
expectedDom: "relay.example.com",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var parsedURL *url.URL
var err error
if tt.urlStr != "" {
parsedURL, err = url.Parse(tt.urlStr)
if err != nil && !tt.expectError {
t.Fatalf("Failed to parse URL: %v", err)
}
}
domain, err := extractDomainFromURL(parsedURL)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedDom, domain.SafeString())
}
})
}
}
func TestResolver_PopulateFromConfig(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := NewResolver()
// Test with IP address - should return error since IP addresses are rejected
mgmtURL, _ := url.Parse("https://127.0.0.1")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.Error(t, err)
assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed)
// No domains should be cached when using IP addresses
domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_ServeDNS(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Add a test domain to the cache - use example.org which is reserved for testing
testDomain, err := domain.FromString("example.org")
if err != nil {
t.Fatalf("Failed to create domain: %v", err)
}
err = resolver.AddDomain(ctx, testDomain)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Test A record query for cached domain
t.Run("Cached domain A record", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
req := new(dns.Msg)
req.SetQuestion("example.org.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer")
})
// Test uncached domain signals to continue to next handler
t.Run("Uncached domain signals continue to next handler", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
req := new(dns.Msg)
req.SetQuestion("unknown.example.com.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
// Zero flag set to true signals the handler chain to continue to next handler
assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain")
})
// Test that subdomains of cached domains are NOT resolved
t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
// Query for a subdomain of our cached domain
req := new(dns.Msg)
req.SetQuestion("sub.example.org.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains")
})
// Test case-insensitive matching
t.Run("Case-insensitive domain matching", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
// Query with different casing
req := new(dns.Msg)
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case")
})
}
func TestResolver_GetCachedDomains(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
testDomain, err := domain.FromString("example.org")
if err != nil {
t.Fatalf("Failed to create domain: %v", err)
}
err = resolver.AddDomain(ctx, testDomain)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
cachedDomains := resolver.GetCachedDomains()
assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain")
assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original")
assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot")
}
func TestResolver_ManagementDomainProtection(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
mgmtURL, _ := url.Parse("https://example.org")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
initialDomains := resolver.GetCachedDomains()
if len(initialDomains) == 0 {
t.Skip("Management domain failed to resolve, skipping test")
}
assert.Equal(t, 1, len(initialDomains), "Should have management domain cached")
assert.Equal(t, "example.org", initialDomains[0].SafeString())
serverDomains := dnsconfig.ServerDomains{
Signal: "google.com",
Relay: []domain.Domain{"cloudflare.com"},
}
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
if err != nil {
t.Logf("Server domains update failed: %v", err)
}
finalDomains := resolver.GetCachedDomains()
managementStillCached := false
for _, d := range finalDomains {
if d.SafeString() == "example.org" {
managementStillCached = true
break
}
}
assert.True(t, managementStillCached, "Management domain should never be removed")
}
// extractDomainFromURL extracts a domain from a URL - test helper function
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
if u == nil {
return "", fmt.Errorf("URL is nil")
}
return dnsconfig.ExtractValidDomain(u.String())
}
func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Set up initial domains using resolvable domains
initialDomains := dnsconfig.ServerDomains{
Signal: "example.org",
Stuns: []domain.Domain{"google.com"},
Turns: []domain.Domain{"cloudflare.com"},
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Verify domains were added
cachedDomains := resolver.GetCachedDomains()
assert.Len(t, cachedDomains, 3)
// Update with empty ServerDomains (simulating partial network map update)
emptyDomains := dnsconfig.ServerDomains{}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains)
assert.NoError(t, err)
// Verify no domains were removed
assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty")
// Verify all original domains are still cached
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 3, "All original domains should still be cached")
}
func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Set up initial complete domains using resolvable domains
initialDomains := dnsconfig.ServerDomains{
Signal: "example.org",
Stuns: []domain.Domain{"google.com"},
Turns: []domain.Domain{"cloudflare.com"},
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
partialDomains := dnsconfig.ServerDomains{
Signal: "github.com",
}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Should remove only the old signal domain
assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
assert.Equal(t, "example.org", removedDomains[0].SafeString())
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains")
domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains {
domainStrings[i] = d.SafeString()
}
assert.Contains(t, domainStrings, "github.com")
assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com")
assert.NotContains(t, domainStrings, "example.org")
}
func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Set up initial complete domains using resolvable domains
initialDomains := dnsconfig.ServerDomains{
Signal: "example.org",
Stuns: []domain.Domain{"google.com"},
Turns: []domain.Domain{"cloudflare.com"},
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
partialDomains := dnsconfig.ServerDomains{
Flow: "github.com",
}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains {
domainStrings[i] = d.SafeString()
}
assert.Contains(t, domainStrings, "example.org")
assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com")
assert.Contains(t, domainStrings, "github.com")
}

View File

@@ -3,9 +3,11 @@ package dns
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"net/url"
"github.com/miekg/dns" "github.com/miekg/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
@@ -17,6 +19,7 @@ type MockServer struct {
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func(domain.List, dns.Handler, int) RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func(domain.List, int) DeregisterHandlerFunc func(domain.List, int)
UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error
} }
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -70,3 +73,14 @@ func (m *MockServer) SearchDomains() []string {
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() { func (m *MockServer) ProbeAvailability() {
} }
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
}
return nil
}
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"net/url"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@@ -15,7 +16,9 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/mgmt"
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -45,6 +48,8 @@ type Server interface {
OnUpdatedHostDNSServer(addrs []netip.AddrPort) OnUpdatedHostDNSServer(addrs []netip.AddrPort)
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
} }
type nsGroupsByDomain struct { type nsGroupsByDomain struct {
@@ -77,6 +82,8 @@ type DefaultServer struct {
handlerChain *HandlerChain handlerChain *HandlerChain
extraDomains map[domain.Domain]int extraDomains map[domain.Domain]int
mgmtCacheResolver *mgmt.Resolver
// permanent related properties // permanent related properties
permanent bool permanent bool
hostsDNSHolder *hostsDNSHolder hostsDNSHolder *hostsDNSHolder
@@ -104,18 +111,20 @@ type handlerWrapper struct {
type registeredHandlerMap map[types.HandlerID]handlerWrapper type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
WgInterface WGIface
CustomAddress string
StatusRecorder *peer.Status
StateManager *statemanager.Manager
DisableSys bool
}
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer( func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) {
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
stateManager *statemanager.Manager,
disableSys bool,
) (*DefaultServer, error) {
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if config.CustomAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
} }
@@ -123,13 +132,14 @@ func NewDefaultServer(
} }
var dnsService service var dnsService service
if wgInterface.IsUserspaceBind() { if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(wgInterface) dnsService = NewServiceViaMemory(config.WgInterface)
} else { } else {
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(config.WgInterface, addrPort)
} }
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
return server, nil
} }
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@@ -178,6 +188,9 @@ func newDefaultServer(
) *DefaultServer { ) *DefaultServer {
handlerChain := NewHandlerChain() handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
@@ -192,6 +205,7 @@ func newDefaultServer(
stateManager: stateManager, stateManager: stateManager,
hostsDNSHolder: newHostsDNSHolder(), hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
} }
// register with root zone, handler chain takes care of the routing // register with root zone, handler chain takes care of the routing
@@ -217,7 +231,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
} }
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
log.Debugf("registering handler %s with priority %d", handler, priority) log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains)
for _, domain := range domains { for _, domain := range domains {
if domain == "" { if domain == "" {
@@ -246,7 +260,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
} }
func (s *DefaultServer) deregisterHandler(domains []string, priority int) { func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler %v with priority %d", domains, priority) log.Debugf("deregistering handler with priority %d for %v", priority, domains)
for _, domain := range domains { for _, domain := range domains {
if domain == "" { if domain == "" {
@@ -432,6 +446,29 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Wait() wg.Wait()
} }
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
s.mux.Lock()
defer s.mux.Unlock()
if s.mgmtCacheResolver != nil {
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
if err != nil {
return fmt.Errorf("update management cache resolver: %w", err)
}
if len(removedDomains) > 0 {
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
}
newDomains := s.mgmtCacheResolver.GetCachedDomains()
if len(newDomains) > 0 {
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
}
}
return nil
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver // is the service should be Disabled, we stop the listener or fake resolver
if update.ServiceEnable { if update.ServiceEnable {
@@ -961,3 +998,11 @@ func toZone(d domain.Domain) domain.Domain {
), ),
) )
} }
// PopulateManagementDomain populates the DNS cache with management domain
func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
if s.mgmtCacheResolver != nil {
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
}
return nil
}

View File

@@ -0,0 +1,5 @@
package dns
func (s *DefaultServer) initialize() (hostManager, error) {
return &noopHostConfigurator{}, nil
}

View File

@@ -363,7 +363,13 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -473,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return return
} }
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil { if err != nil {
t.Errorf("create DNS server: %v", err) t.Errorf("create DNS server: %v", err)
return return
@@ -575,7 +587,13 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: &mocWGIface{},
CustomAddress: testCase.addrPort,
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
type ServiceViaMemory struct { type ServiceViaMemory struct {

View File

@@ -0,0 +1,19 @@
package dns
import (
"context"
)
type ShutdownState struct{}
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
return nil
}
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
return nil
}

View File

@@ -33,9 +33,11 @@ func SetCurrentMTU(mtu uint16) {
} }
const ( const (
UpstreamTimeout = 15 * time.Second UpstreamTimeout = 4 * time.Second
// ClientTimeout is the timeout for the dns.Client.
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 5 * time.Second
failsTillDeact = int32(5)
reactivatePeriod = 30 * time.Second reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second probeTimeout = 2 * time.Second
) )
@@ -58,9 +60,7 @@ type upstreamResolverBase struct {
upstreamServers []netip.AddrPort upstreamServers []netip.AddrPort
domain string domain string
disabled bool disabled bool
failsCount atomic.Int32
successCount atomic.Int32 successCount atomic.Int32
failsTillDeact int32
mutex sync.Mutex mutex sync.Mutex
reactivatePeriod time.Duration reactivatePeriod time.Duration
upstreamTimeout time.Duration upstreamTimeout time.Duration
@@ -79,14 +79,13 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
domain: domain, domain: domain,
upstreamTimeout: UpstreamTimeout, upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
} }
} }
// String returns a string representation of the upstream resolver // String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string { func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %s", u.upstreamServers) return fmt.Sprintf("Upstream %s", u.upstreamServers)
} }
// ID returns the unique handler ID // ID returns the unique handler ID
@@ -116,58 +115,102 @@ func (u *upstreamResolverBase) Stop() {
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID() requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID) logger := log.WithField("request_id", requestID)
var err error
defer func() {
u.checkUpstreamFails(err)
}()
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
u.prepareRequest(r)
if u.ctx.Err() != nil {
logger.Tracef("%s has been stopped", u)
return
}
if u.tryUpstreamServers(w, r, logger) {
return
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
}
select { func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
case <-u.ctx.Done(): timeout := u.upstreamTimeout
logger.Tracef("%s has been stopped", u) if len(u.upstreamServers) > 1 {
return maxTotal := 5 * time.Second
default: minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
if u.queryUpstream(w, r, upstream, timeout, logger) {
return true
}
}
return false
}
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
var rm *dns.Msg var rm *dns.Msg
var t time.Duration var t time.Duration
var err error
var startTime time.Time
func() { func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel() defer cancel()
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}() }()
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) return false
continue
}
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
} }
if rm == nil || !rm.Response { if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue return false
} }
u.successCount.Add(1) return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
} }
// count the fails only if they happen sequentially
u.failsCount.Store(0) func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return return
} }
u.failsCount.Add(1)
elapsed := time.Since(startTime)
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
timeoutMsg += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warnf(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
}
return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg) m := new(dns.Msg)
@@ -177,41 +220,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
// checkUpstreamFails counts fails and disables or enables upstream resolving
//
// If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated.
func (u *upstreamResolverBase) checkUpstreamFails(err error) {
u.mutex.Lock()
defer u.mutex.Unlock()
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
return
}
select {
case <-u.ctx.Done():
return
default:
}
u.disable(err)
if u.statusRecorder == nil {
return
}
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
"All upstream servers failed (fail count exceeded)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": u.upstreamServersString()},
// TODO add domain meta
)
}
// ProbeAvailability tests all upstream servers simultaneously and // ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work // disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() { func (u *upstreamResolverBase) ProbeAvailability() {
@@ -224,8 +232,8 @@ func (u *upstreamResolverBase) ProbeAvailability() {
default: default:
} }
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact // avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { if u.successCount.Load() > 0 {
return return
} }
@@ -312,7 +320,6 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.failsCount.Store(0)
u.successCount.Add(1) u.successCount.Add(1)
u.reactivate() u.reactivate()
u.disabled = false u.disabled = false
@@ -416,3 +423,80 @@ func GenerateRequestID() string {
} }
return hex.EncodeToString(bytes) return hex.EncodeToString(bytes)
} }
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected
hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() &&
time.Since(peerState.LastWireguardHandshake) < 3*time.Minute
statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)
switch {
case !isConnected:
statusInfo += " DISCONNECTED"
case !hasRecentHandshake:
statusInfo += " NO_RECENT_HANDSHAKE"
default:
statusInfo += " connected"
}
if !peerState.LastWireguardHandshake.IsZero() {
timeSinceHandshake := time.Since(peerState.LastWireguardHandshake)
statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
} else {
statusInfo += " no_handshake"
}
if peerState.Relayed {
statusInfo += " via_relay"
}
if peerState.Latency > 0 {
statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
}
return statusInfo
}
// findPeerForIP finds which peer handles the given IP address
func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
if statusRecorder == nil {
return nil
}
fullStatus := statusRecorder.GetFullStatus()
var bestMatch *peer.State
var bestPrefixLen int
for _, peerState := range fullStatus.Peers {
routes := peerState.GetRoutes()
for route := range routes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
continue
}
if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen {
peerStateCopy := peerState
bestMatch = &peerStateCopy
bestPrefixLen = prefix.Bits()
}
}
}
return bestMatch
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
type upstreamResolver struct { type upstreamResolver struct {
@@ -50,7 +50,9 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns
} }
func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{} upstreamExchangeClient := &dns.Client{
Timeout: ClientTimeout,
}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
} }
@@ -73,9 +75,10 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
upstreamExchangeClient := &dns.Client{ upstreamExchangeClient := &dns.Client{
Dialer: dialer, Dialer: dialer,
Timeout: timeout,
} }
return upstreamExchangeClient.Exchange(r, upstream) return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
} }
func (u *upstreamResolver) isLocalResolver(upstream string) bool { func (u *upstreamResolver) isLocalResolver(upstream string) bool {

View File

@@ -34,7 +34,10 @@ func newUpstreamResolver(
} }
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) client := &dns.Client{
Timeout: ClientTimeout,
}
return ExchangeWithFallback(ctx, client, r, upstream)
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {

View File

@@ -47,7 +47,9 @@ func newUpstreamResolver(
} }
func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
client := &dns.Client{} client := &dns.Client{
Timeout: ClientTimeout,
}
upstreamHost, _, err := net.SplitHostPort(upstream) upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
@@ -111,6 +113,7 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura
} }
client := &dns.Client{ client := &dns.Client{
Dialer: dialer, Dialer: dialer,
Timeout: dialTimeout,
} }
return client, nil return client, nil
} }

View File

@@ -124,29 +124,26 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
} }
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
resolver := &upstreamResolverBase{ mockClient := &mockUpstreamResolver{
ctx: context.TODO(), err: dns.ErrTime,
upstreamClient: &mockUpstreamResolver{
err: nil,
r: new(dns.Msg), r: new(dns.Msg),
rtt: time.Millisecond, rtt: time.Millisecond,
}, }
resolver := &upstreamResolverBase{
ctx: context.TODO(),
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout, upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: time.Microsecond * 100,
failsTillDeact: failsTillDeact,
} }
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { return nil },
}
failed := false failed := false
resolver.deactivate = func(error) { resolver.deactivate = func(error) {
failed = true failed = true
// After deactivation, make the mock client work again
mockClient.err = nil
} }
reactivated := false reactivated := false
@@ -154,7 +151,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivated = true reactivated = true
} }
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) resolver.ProbeAvailability()
if !failed { if !failed {
t.Errorf("expected that resolving was deactivated") t.Errorf("expected that resolving was deactivated")
@@ -173,11 +170,6 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
return return
} }
if resolver.failsCount.Load() != 0 {
t.Errorf("fails count after reactivation should be 0")
return
}
if resolver.disabled { if resolver.disabled {
t.Errorf("should be enabled") t.Errorf("should be enabled")
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -11,13 +12,17 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
listenPort uint16 = 5353
listenPortMu sync.RWMutex
) )
const ( const (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
ListenPort = 5353
dnsTTL = 60 //seconds dnsTTL = 60 //seconds
) )
@@ -35,12 +40,20 @@ type Manager struct {
fwRules []firewall.Rule fwRules []firewall.Rule
tcpRules []firewall.Rule tcpRules []firewall.Rule
dnsForwarder *DNSForwarder dnsForwarder *DNSForwarder
port uint16
} }
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { func ListenPort() uint16 {
listenPortMu.RLock()
defer listenPortMu.RUnlock()
return listenPort
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
return &Manager{ return &Manager{
firewall: fw, firewall: fw,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
port: port,
} }
} }
@@ -54,7 +67,13 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err return err
} }
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder) if m.port > 0 {
listenPortMu.Lock()
listenPort = m.port
listenPortMu.Unlock()
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
go func() { go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil { if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists // todo handle close error if it is exists
@@ -94,7 +113,7 @@ func (m *Manager) Stop(ctx context.Context) error {
func (m *Manager) allowDNSFirewall() error { func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{ dport := &firewall.Port{
IsRange: false, IsRange: false,
Values: []uint16{ListenPort}, Values: []uint16{ListenPort()},
} }
if m.firewall == nil { if m.firewall == nil {

View File

@@ -7,6 +7,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
"net/url"
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
@@ -17,8 +18,8 @@ import (
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -28,11 +29,12 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -164,7 +166,7 @@ type Engine struct {
wgInterface WGIface wgInterface WGIface
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
@@ -196,6 +198,13 @@ type Engine struct {
latestSyncResponse *mgmProto.SyncResponse latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager flowManager nftypes.FlowManager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
// dns forwarder port
dnsFwdPort uint16
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -238,6 +247,7 @@ func NewEngine(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
dnsFwdPort: dnsfwd.ListenPort(),
} }
sm := profilemanager.NewServiceManager("") sm := profilemanager.NewServiceManager("")
@@ -339,13 +349,16 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
return nil return nil
} }
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here. // Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service // However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start() error { func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -401,6 +414,11 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
// Populate DNS cache with NetbirdConfig and management URL for early resolution
if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache: %v", err)
}
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Context: e.ctx, Context: e.ctx,
PublicKey: e.config.WgPrivateKey.PublicKey().String(), PublicKey: e.config.WgPrivateKey.PublicKey().String(),
@@ -450,14 +468,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("initialize dns server: %w", err) return fmt.Errorf("initialize dns server: %w", err)
} }
iceCfg := icemaker.Config{ iceCfg := e.createICEConfig()
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface) e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
e.connMgr.Start(e.ctx) e.connMgr.Start(e.ctx)
@@ -470,6 +481,22 @@ func (e *Engine) Start() error {
// starting network monitor at the very last to avoid disruptions // starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor() e.startNetworkMonitor()
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
go func() {
defer e.wgIfaceMonitorWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
}()
return nil return nil
} }
@@ -661,6 +688,30 @@ func (e *Engine) removePeer(peerKey string) error {
return nil return nil
} }
// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response
func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
if e.dnsServer == nil {
return nil
}
// Populate management URL if provided
if mgmtURL != nil {
if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache with management URL: %v", err)
}
}
// Populate NetbirdConfig domains if provided
if netbirdConfig != nil {
serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err)
}
}
return nil
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -692,6 +743,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("handle the flow configuration: %w", err) return fmt.Errorf("handle the flow configuration: %w", err)
} }
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal // todo update signal
} }
@@ -914,7 +969,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.LazyConnectionEnabled, e.config.LazyConnectionEnabled,
) )
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil { if err != nil {
// happens if management is unavailable for a long time. // happens if management is unavailable for a long time.
@@ -925,7 +979,7 @@ func (e *Engine) receiveManagementEvents() {
} }
log.Debugf("stopped receiving updates from Management Service") log.Debugf("stopped receiving updates from Management Service")
}() }()
log.Debugf("connecting to Management Service updates stream") log.Infof("connecting to Management Service updates stream")
} }
func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {
@@ -1030,7 +1084,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
// Ingress forward rules // Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
@@ -1288,14 +1342,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
Addr: e.getRosenpassAddr(), Addr: e.getRosenpassAddr(),
PermissiveMode: e.config.RosenpassPermissive, PermissiveMode: e.config.RosenpassPermissive,
}, },
ICEConfig: icemaker.Config{ ICEConfig: e.createICEConfig(),
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
},
} }
serviceDependencies := peer.ServiceDependencies{ serviceDependencies := peer.ServiceDependencies{
@@ -1557,7 +1604,14 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil return dnsServer, nil
default: default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{
WgInterface: e.wgInterface,
CustomAddress: e.config.CustomDNSAddress,
StatusRecorder: e.statusRecorder,
StateManager: e.stateManager,
DisableSys: e.config.DisableDNS,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1789,6 +1843,7 @@ func (e *Engine) GetWgAddr() netip.Addr {
func (e *Engine) updateDNSForwarder( func (e *Engine) updateDNSForwarder(
enabled bool, enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry, fwdEntries []*dnsfwd.ForwarderEntry,
forwarderPort uint16,
) { ) {
if e.config.DisableServerRoutes { if e.config.DisableServerRoutes {
return return
@@ -1805,16 +1860,20 @@ func (e *Engine) updateDNSForwarder(
} }
if len(fwdEntries) > 0 { if len(fwdEntries) > 0 {
if e.dnsForwardMgr == nil { switch {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) case e.dnsForwardMgr == nil:
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err) log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
} }
log.Infof("started domain router service with %d entries", len(fwdEntries)) log.Infof("started domain router service with %d entries", len(fwdEntries))
} else { case e.dnsFwdPort != forwarderPort:
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
e.restartDnsFwd(fwdEntries, forwarderPort)
e.dnsFwdPort = forwarderPort
default:
e.dnsForwardMgr.UpdateDomains(fwdEntries) e.dnsForwardMgr.UpdateDomains(fwdEntries)
} }
} else if e.dnsForwardMgr != nil { } else if e.dnsForwardMgr != nil {
@@ -1824,6 +1883,20 @@ func (e *Engine) updateDNSForwarder(
} }
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
} }
}
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
// stop and start the forwarder to apply the new port
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
} }
func (e *Engine) GetNet() (*netstack.Net, error) { func (e *Engine) GetNet() (*netstack.Net, error) {

Some files were not shown because too many files have changed in this diff Show More