mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-16 12:59:55 +00:00
Compare commits
51 Commits
use-conntr
...
remove-ubu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e273c9017 | ||
|
|
1a6d6b3109 | ||
|
|
f686615876 | ||
|
|
a4311f574d | ||
|
|
0bb8eae903 | ||
|
|
e0b33d325d | ||
|
|
c38e07d89a | ||
|
|
a37368fff4 | ||
|
|
0c93bd3d06 | ||
|
|
a675531b5c | ||
|
|
7cb366bc7d | ||
|
|
a354004564 | ||
|
|
75bdd47dfb | ||
|
|
b165f63327 | ||
|
|
51bb52cdf5 | ||
|
|
4134b857b4 | ||
|
|
7839d2c169 | ||
|
|
b9f82e2f8a | ||
|
|
fd2a21c65d | ||
|
|
82d982b0ab | ||
|
|
9e24fe7701 | ||
|
|
e470701b80 | ||
|
|
e3ce026355 | ||
|
|
5ea2806663 | ||
|
|
d6b0673580 | ||
|
|
14913cfa7a | ||
|
|
03f600b576 | ||
|
|
192c97aa63 | ||
|
|
4db78db49a | ||
|
|
87e600a4f3 | ||
|
|
6162aeb82d | ||
|
|
1ba1e092ce | ||
|
|
86dbb4ee4f | ||
|
|
4af177215f | ||
|
|
df9c1b9883 | ||
|
|
5752bb78f2 | ||
|
|
fbd783ad58 | ||
|
|
80702b9323 | ||
|
|
09243a0fe0 | ||
|
|
3658215747 | ||
|
|
48ffec95dd | ||
|
|
cbec7bda80 | ||
|
|
21464ac770 | ||
|
|
ed5647028a | ||
|
|
29a6e5be71 | ||
|
|
6124e3b937 | ||
|
|
50f5cc48cd | ||
|
|
101cce27f2 | ||
|
|
a4f04f5570 | ||
|
|
fceb3ca392 | ||
|
|
34d86c5ab8 |
27
.git-branches.toml
Normal file
27
.git-branches.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
# More info around this file at https://www.git-town.com/configuration-file
|
||||
|
||||
[branches]
|
||||
main = "main"
|
||||
perennials = []
|
||||
perennial-regex = ""
|
||||
|
||||
[create]
|
||||
new-branch-type = "feature"
|
||||
push-new-branches = false
|
||||
|
||||
[hosting]
|
||||
dev-remote = "origin"
|
||||
# platform = ""
|
||||
# origin-hostname = ""
|
||||
|
||||
[ship]
|
||||
delete-tracking-branch = false
|
||||
strategy = "squash-merge"
|
||||
|
||||
[sync]
|
||||
feature-strategy = "merge"
|
||||
perennial-strategy = "rebase"
|
||||
prototype-strategy = "merge"
|
||||
push-hook = true
|
||||
tags = true
|
||||
upstream = false
|
||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
@@ -2,6 +2,10 @@
|
||||
|
||||
## Issue ticket number and link
|
||||
|
||||
## Stack
|
||||
|
||||
<!-- branch-stack -->
|
||||
|
||||
### Checklist
|
||||
- [ ] Is it a bug fix
|
||||
- [ ] Is a typo/documentation fix
|
||||
|
||||
21
.github/workflows/git-town.yml
vendored
Normal file
21
.github/workflows/git-town.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Git Town
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
name: Display the branch stack
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
10
.github/workflows/golang-test-freebsd.yml
vendored
10
.github/workflows/golang-test-freebsd.yml
vendored
@@ -22,14 +22,20 @@ jobs:
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.1"
|
||||
release: "14.2"
|
||||
prepare: |
|
||||
pkg install -y go pkgconf xorg
|
||||
pkg install -y curl pkgconf xorg
|
||||
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
|
||||
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
run: |
|
||||
set -e -x
|
||||
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
|
||||
42
.github/workflows/golang-test-linux.yml
vendored
42
.github/workflows/golang-test-linux.yml
vendored
@@ -314,6 +314,7 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=devcert \
|
||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||
-timeout 20m ./management/...
|
||||
@@ -380,7 +381,8 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./...
|
||||
@@ -396,6 +398,33 @@ jobs:
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
run: docker network create promnet
|
||||
|
||||
- name: Start Prometheus Pushgateway
|
||||
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
|
||||
|
||||
- name: Start Prometheus (for Pushgateway forwarding)
|
||||
run: |
|
||||
echo '
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
scrape_configs:
|
||||
- job_name: "pushgateway"
|
||||
static_configs:
|
||||
- targets: ["pushgateway:9091"]
|
||||
remote_write:
|
||||
- url: ${{ secrets.GRAFANA_URL }}
|
||||
basic_auth:
|
||||
username: ${{ secrets.GRAFANA_USER }}
|
||||
password: ${{ secrets.GRAFANA_API_KEY }}
|
||||
' > prometheus.yml
|
||||
|
||||
docker run -d --name prometheus --network promnet \
|
||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
@@ -447,11 +476,13 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/...
|
||||
|
||||
api_integration_test:
|
||||
@@ -505,7 +536,8 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=integration \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/...
|
||||
@@ -513,7 +545,7 @@ jobs:
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-20.04
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
||||
@@ -178,6 +178,7 @@ jobs:
|
||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
20
README.md
20
README.md
@@ -57,16 +57,16 @@
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
|
||||
@@ -11,9 +11,12 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
@@ -326,3 +329,34 @@ func formatDuration(d time.Duration) string {
|
||||
s := d / time.Second
|
||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||
}
|
||||
|
||||
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
|
||||
var networkMap *mgmProto.NetworkMap
|
||||
var err error
|
||||
|
||||
if connectClient != nil {
|
||||
networkMap, err = connectClient.GetLatestNetworkMap()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get latest network map: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
debug.GeneratorDependencies{
|
||||
InternalConfig: config,
|
||||
StatusRecorder: recorder,
|
||||
NetworkMap: networkMap,
|
||||
LogFile: logFilePath,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
},
|
||||
)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate debug bundle: %v", err)
|
||||
return
|
||||
}
|
||||
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||
}
|
||||
|
||||
39
client/cmd/debug_unix.go
Normal file
39
client/cmd/debug_unix.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
func SetupDebugHandler(
|
||||
ctx context.Context,
|
||||
config *internal.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
) {
|
||||
usr1Ch := make(chan os.Signal, 1)
|
||||
|
||||
signal.Notify(usr1Ch, syscall.SIGUSR1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-usr1Ch:
|
||||
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
|
||||
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
126
client/cmd/debug_windows.go
Normal file
126
client/cmd/debug_windows.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
|
||||
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
|
||||
|
||||
waitTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
|
||||
// Example usage with PowerShell:
|
||||
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
|
||||
// $evt.Set()
|
||||
// $evt.Close()
|
||||
func SetupDebugHandler(
|
||||
ctx context.Context,
|
||||
config *internal.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
) {
|
||||
env := os.Getenv(envListenEvent)
|
||||
if env == "" {
|
||||
return
|
||||
}
|
||||
|
||||
listenEvent, err := strconv.ParseBool(env)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
|
||||
return
|
||||
}
|
||||
if !listenEvent {
|
||||
return
|
||||
}
|
||||
|
||||
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: restrict access by ACL
|
||||
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
||||
if err != nil {
|
||||
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
||||
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
|
||||
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
||||
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
|
||||
} else {
|
||||
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if eventHandle == windows.InvalidHandle {
|
||||
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
|
||||
|
||||
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
|
||||
}
|
||||
|
||||
func waitForEvent(
|
||||
ctx context.Context,
|
||||
config *internal.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
eventHandle windows.Handle,
|
||||
) {
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(eventHandle); err != nil {
|
||||
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
||||
|
||||
switch status {
|
||||
case windows.WAIT_OBJECT_0:
|
||||
log.Info("Received signal on debug event. Triggering debug bundle generation.")
|
||||
|
||||
// reset the event so it can be triggered again later (manual reset == 1)
|
||||
if err := windows.ResetEvent(eventHandle); err != nil {
|
||||
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
|
||||
}
|
||||
|
||||
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||
case uint32(windows.WAIT_TIMEOUT):
|
||||
|
||||
default:
|
||||
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,10 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "login to the Netbird Management Service (first run)",
|
||||
@@ -127,7 +131,7 @@ var loginCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
@@ -198,7 +202,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
@@ -212,19 +216,27 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||
var codeMsg string
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
}
|
||||
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
if noBrowser {
|
||||
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
|
||||
} else {
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
}
|
||||
|
||||
cmd.Println("")
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
|
||||
if !noBrowser {
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -115,6 +115,7 @@ var runCmd = &cobra.Command{
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
SetupDebugHandler(ctx, nil, nil, nil, logFile)
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
if err != nil {
|
||||
|
||||
@@ -12,9 +12,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -32,7 +34,7 @@ import (
|
||||
|
||||
func startTestingServices(t *testing.T) string {
|
||||
t.Helper()
|
||||
config := &mgmt.Config{}
|
||||
config := &types.Config{}
|
||||
_, err := util.ReadJson("../testdata/management.json", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -67,7 +69,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
return s, lis
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
@@ -90,13 +92,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -32,12 +32,16 @@ const (
|
||||
|
||||
const (
|
||||
dnsLabelsFlag = "extra-dns-labels"
|
||||
|
||||
noBrowserFlag = "no-browser"
|
||||
noBrowserDesc = "do not open the browser for SSO login"
|
||||
)
|
||||
|
||||
var (
|
||||
foregroundMode bool
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
@@ -65,6 +69,9 @@ func init() {
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`,
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -212,6 +219,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
}
|
||||
|
||||
@@ -349,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
|
||||
@@ -38,10 +38,12 @@ const (
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
matchSet = "--match-set"
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
@@ -115,6 +117,10 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
@@ -348,12 +354,16 @@ func (r *router) Reset() error {
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
r.rules = make(map[string][]string)
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
r.rules = make(map[string][]string)
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
@@ -423,6 +433,57 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *router) setupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
preRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePre] = preRule
|
||||
}
|
||||
|
||||
postRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePost] = postRule
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) cleanupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
if preRule, exists := r.rules[markManglePre]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePre)
|
||||
}
|
||||
}
|
||||
|
||||
if postRule, exists := r.rules[markManglePost]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePost)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
@@ -464,7 +525,7 @@ func (r *router) insertEstablishedRule(chain string) error {
|
||||
}
|
||||
|
||||
func (r *router) addJumpRules() error {
|
||||
// Jump to NAT chain
|
||||
// Jump to nat chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||
|
||||
@@ -46,7 +46,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
// 5. jump rule to PRE nat chain
|
||||
// 6. static outbound masquerade rule
|
||||
// 7. static return masquerade rule
|
||||
require.Len(t, manager.rules, 7, "should have created rules map")
|
||||
// 8. mangle prerouting mark rule
|
||||
// 9. mangle postrouting mark rule
|
||||
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
|
||||
@@ -25,9 +25,10 @@ const (
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNamePrerouting = "netbird-rt-prerouting"
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
@@ -462,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
// Chain is created by route manager
|
||||
// TODO: move creation to a common place
|
||||
m.chainPrerouting = &nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
})
|
||||
}
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
|
||||
@@ -100,6 +100,10 @@ func (r *router) init(workTable *nftables.Table) error {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -196,15 +200,21 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Chain is created by acl manager
|
||||
// TODO: move creation to a common place
|
||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePostrouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
}
|
||||
})
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
@@ -220,7 +230,83 @@ func (r *router) createContainers() error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *router) setupDataPlaneMark() error {
|
||||
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||
return errors.New("no mangle chains found")
|
||||
}
|
||||
|
||||
ctNew := getCtNewExprs()
|
||||
preExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
preExprs = append(preExprs, ctNew...)
|
||||
preExprs = append(preExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
preNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: preExprs,
|
||||
}
|
||||
r.conn.AddRule(preNftRule)
|
||||
|
||||
postExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
postExprs = append(postExprs, ctNew...)
|
||||
postExprs = append(postExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
postNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePostrouting],
|
||||
Exprs: postExprs,
|
||||
}
|
||||
r.conn.AddRule(postNftRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -516,26 +602,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
op = expr.CmpOpNeq
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
exprs := getCtNewExprs()
|
||||
exprs = append(exprs,
|
||||
// interface matching
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
@@ -546,7 +616,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
@@ -578,7 +648,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNamePrerouting],
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
@@ -1324,3 +1394,24 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
|
||||
return exprs
|
||||
}
|
||||
|
||||
func getCtNewExprs() []expr.Any {
|
||||
return []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := 0
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNamePrerouting {
|
||||
if chain.Name == chainNameManglePrerouting {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
// Verify the rule was removed
|
||||
found = false
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -12,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
|
||||
@@ -189,7 +189,7 @@ func (t *ICMPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
@@ -23,11 +23,11 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPAck uint8 = 0x10
|
||||
TCPFin uint8 = 0x01
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPRst uint8 = 0x04
|
||||
TCPPush uint8 = 0x08
|
||||
TCPAck uint8 = 0x10
|
||||
TCPUrg uint8 = 0x20
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
type TCPState int
|
||||
type TCPState int32
|
||||
|
||||
func (s TCPState) String() string {
|
||||
switch s {
|
||||
@@ -89,22 +89,25 @@ const (
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
tombstone atomic.Bool
|
||||
sync.RWMutex
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
state atomic.Int32
|
||||
tombstone atomic.Bool
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (t *TCPConnTrack) IsEstablished() bool {
|
||||
return t.established.Load()
|
||||
// GetState safely retrieves the current state
|
||||
func (t *TCPConnTrack) GetState() TCPState {
|
||||
return TCPState(t.state.Load())
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||
t.established.Store(state)
|
||||
// SetState safely updates the current state
|
||||
func (t *TCPConnTrack) SetState(state TCPState) {
|
||||
t.state.Store(int32(state))
|
||||
}
|
||||
|
||||
// CompareAndSwapState atomically changes the state from old to new if current == old
|
||||
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
|
||||
return t.state.CompareAndSwap(int32(old), int32(newState))
|
||||
}
|
||||
|
||||
// IsTombstone safely checks if the connection is marked for deletion
|
||||
@@ -125,13 +128,17 @@ type TCPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||
waitTimeout := TimeWaitTimeout
|
||||
if timeout == 0 {
|
||||
timeout = DefaultTCPTimeout
|
||||
} else {
|
||||
waitTimeout = timeout / 45
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -142,6 +149,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -149,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -162,12 +170,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||
conn.Unlock()
|
||||
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, true
|
||||
}
|
||||
|
||||
@@ -175,7 +178,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
@@ -183,14 +186,14 @@ func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort u
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists {
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -205,12 +208,11 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
DestPort: dstPort,
|
||||
}
|
||||
|
||||
conn.established.Store(false)
|
||||
conn.tombstone.Store(false)
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
|
||||
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||
t.updateState(key, conn, flags, direction == nftypes.Egress)
|
||||
conn.UpdateCounters(direction, size)
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
@@ -220,7 +222,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
|
||||
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
@@ -232,134 +234,125 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
if !exists || conn.IsTombstone() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle RST flag specially - it always causes transition to closed
|
||||
if flags&TCPRst != 0 {
|
||||
return t.handleRst(key, conn, size)
|
||||
currentState := conn.GetState()
|
||||
if !t.isValidStateForFlags(currentState, flags) {
|
||||
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
// allow all flags for established for now
|
||||
if currentState == TCPStateEstablished {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, false)
|
||||
isEstablished := conn.IsEstablished()
|
||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||
conn.Unlock()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
return isEstablished || isValidState
|
||||
}
|
||||
|
||||
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, size int) bool {
|
||||
if conn.IsTombstone() {
|
||||
return true
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
conn.SetTombstone()
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.updateState(key, conn, flags, nftypes.Ingress, size)
|
||||
return true
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
state := conn.State
|
||||
defer func() {
|
||||
if state != conn.State {
|
||||
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
|
||||
currentState := conn.GetState()
|
||||
|
||||
if flags&TCPRst != 0 {
|
||||
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
switch state {
|
||||
var newState TCPState
|
||||
switch currentState {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
conn.State = TCPStateSynSent
|
||||
if conn.Direction == nftypes.Egress {
|
||||
newState = TCPStateSynSent
|
||||
} else {
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
if packetDir != conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
} else {
|
||||
// Simultaneous open
|
||||
conn.State = TCPStateSynReceived
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateFinWait1
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateFinWait1
|
||||
} else {
|
||||
conn.State = TCPStateCloseWait
|
||||
newState = TCPStateCloseWait
|
||||
}
|
||||
conn.SetEstablished(false)
|
||||
} else if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateFinWait1:
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
conn.State = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPAck != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPRst != 0:
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
if packetDir != conn.Direction {
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
newState = TCPStateFinWait2
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
|
||||
t.logger.Trace("TCP connection %s completed", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateLastAck
|
||||
newState = TCPStateLastAck
|
||||
}
|
||||
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
newState = TCPStateClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Send close event for gracefully closed connections
|
||||
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||
|
||||
switch newState {
|
||||
case TCPStateTimeWait:
|
||||
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -369,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
if flags&TCPRst != 0 {
|
||||
if state == TCPStateSynSent {
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch state {
|
||||
case TCPStateNew:
|
||||
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||
case TCPStateSynSent:
|
||||
// TODO: support simultaneous open
|
||||
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||
case TCPStateSynReceived:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPRst != 0 {
|
||||
return true
|
||||
}
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateFinWait1:
|
||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||
@@ -397,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
case TCPStateLastAck:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateClosed:
|
||||
// Accept retransmitted ACKs in closed state
|
||||
// This is important because the final ACK might be lost
|
||||
// and the peer will retransmit their FIN-ACK
|
||||
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return false
|
||||
@@ -430,23 +425,24 @@ func (t *TCPTracker) cleanup() {
|
||||
}
|
||||
|
||||
var timeout time.Duration
|
||||
switch {
|
||||
case conn.State == TCPStateTimeWait:
|
||||
timeout = TimeWaitTimeout
|
||||
case conn.IsEstablished():
|
||||
currentState := conn.GetState()
|
||||
switch currentState {
|
||||
case TCPStateTimeWait:
|
||||
timeout = t.waitTimeout
|
||||
case TCPStateEstablished:
|
||||
timeout = t.timeout
|
||||
default:
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
// Return IPs to pool
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
|
||||
// event already handled by state change
|
||||
if conn.State != TCPStateTimeWait {
|
||||
if currentState != TCPStateTimeWait {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -124,9 +125,6 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
// Receive RST
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||
// The connection will be cleaned up by timeout
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -217,97 +215,446 @@ func TestRSTHandling(t *testing.T) {
|
||||
conn := tracker.connections[key]
|
||||
if tt.wantValid {
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateClosed, conn.State)
|
||||
require.False(t, conn.IsEstablished())
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPRetransmissions(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test SYN retransmission
|
||||
t.Run("SYN Retransmission", func(t *testing.T) {
|
||||
// Initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Retransmit SYN (should not affect the state machine)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Verify we're still in SYN-SENT state
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||
|
||||
// Complete the handshake
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Verify we're in ESTABLISHED state
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test ACK retransmission in established state
|
||||
t.Run("ACK Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Retransmit ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test FIN retransmission
|
||||
t.Run("FIN Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Retransmit FIN (should not change state)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPDataTransfer(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Data Transfer", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Receive ACK for data
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
|
||||
require.True(t, valid)
|
||||
|
||||
// Receive data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send ACK for received data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
|
||||
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
|
||||
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
|
||||
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPHalfClosedConnections(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test half-closed connection: local end closes, remote end continues sending data
|
||||
t.Run("Local Close, Remote Data", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Remote end can still send data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
|
||||
require.True(t, valid)
|
||||
|
||||
// We can still ACK their data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Receive FIN from remote end
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain TIME-WAIT (waiting for possible retransmissions)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
})
|
||||
|
||||
// Test half-closed connection: remote end closes, local end continues sending data
|
||||
t.Run("Remote Close, Local Data", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Receive FIN from remote
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
|
||||
// We can still send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Remote can still ACK our data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send our FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Receive final ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPAbnormalSequences(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test handling of unsolicited RST in various states
|
||||
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
|
||||
// Send SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Receive unsolicited RST (without proper ACK)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
|
||||
|
||||
// Receive RST with proper ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
require.True(t, conn.IsTombstone())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPTimeoutHandling(t *testing.T) {
|
||||
// Create tracker with a very short timeout for testing
|
||||
shortTimeout := 100 * time.Millisecond
|
||||
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Connection Timeout", func(t *testing.T) {
|
||||
// Establish a connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Wait for the connection to timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
// Force cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after timeout")
|
||||
})
|
||||
|
||||
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Complete the connection close to enter TIME_WAIT
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// TIME_WAIT should have its own timeout value (usually 2*MSL)
|
||||
// For the test, we're using a short timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSynFlood(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
basePort := uint16(10000)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Create a large number of SYN packets to simulate a SYN flood
|
||||
for i := uint16(0); i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Check that we're tracking all connections
|
||||
require.Equal(t, 1000, len(tracker.connections))
|
||||
|
||||
// Now simulate SYN timeout
|
||||
var oldConns int
|
||||
tracker.mutex.Lock()
|
||||
for _, conn := range tracker.connections {
|
||||
if conn.GetState() == TCPStateSynSent {
|
||||
// Make the connection appear old
|
||||
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
|
||||
oldConns++
|
||||
}
|
||||
}
|
||||
tracker.mutex.Unlock()
|
||||
require.Equal(t, 1000, oldConns)
|
||||
|
||||
// Run cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Check that stale connections were cleaned up
|
||||
require.Equal(t, 0, len(tracker.connections))
|
||||
}
|
||||
|
||||
func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
clientIP := netip.MustParseAddr("100.64.0.1")
|
||||
serverIP := netip.MustParseAddr("100.64.0.2")
|
||||
clientPort := uint16(12345)
|
||||
serverPort := uint16(80)
|
||||
|
||||
// 1. Client sends SYN (we receive it as inbound)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: clientIP,
|
||||
DstIP: serverIP,
|
||||
SrcPort: clientPort,
|
||||
DstPort: serverPort,
|
||||
}
|
||||
|
||||
tracker.mutex.RLock()
|
||||
conn := tracker.connections[key]
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
|
||||
|
||||
// 2. Server sends SYN-ACK response
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
|
||||
// 3. Client sends ACK to complete handshake
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||
|
||||
// 4. Test data transfer
|
||||
// Client sends data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||
|
||||
// Server sends ACK for data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
|
||||
// Server sends data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||
|
||||
// Client sends ACK for data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
|
||||
// Verify state and counters
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
|
||||
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
|
||||
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
|
||||
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
|
||||
}
|
||||
|
||||
// Helper to establish a TCP connection
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
}
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
@@ -14,8 +14,13 @@ import (
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||
ipv4Bitmap [1 << 16]uint32
|
||||
// fixed-size high array for upper byte of a IPv4 address
|
||||
ipv4Bitmap [256]*ipv4LowBitmap
|
||||
}
|
||||
|
||||
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||
type ipv4LowBitmap struct {
|
||||
bitmap [8192]uint32
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
@@ -27,35 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
if int(high) >= len(*newIPv4Bitmap) {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
ipStr := ip.String()
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
ipStr := ipv4.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := uint16(ip[0])
|
||||
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -73,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -86,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [1 << 16]uint32
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
|
||||
// 127.0.0.0/8
|
||||
high := uint16(127) << 8
|
||||
for i := uint16(0); i < 256; i++ {
|
||||
newIPv4Bitmap[high|i] = 0xffffffff
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
|
||||
if iface != nil {
|
||||
@@ -120,12 +149,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
if !ip.Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ip.Is4() {
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
return false
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
@@ -77,6 +77,18 @@ func TestLocalIPManager(t *testing.T) {
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match - addresses 32 apart",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
@@ -192,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap version
|
||||
bitmapManager := &localIPManager{
|
||||
ipv4Bitmap: [1 << 16]uint32{},
|
||||
}
|
||||
// Setup bitmap
|
||||
bitmapManager := newLocalIPManager()
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
@@ -248,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm := newLocalIPManager()
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -257,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm := newLocalIPManager()
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
|
||||
@@ -658,7 +658,8 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if !m.isValidPacket(d, packetData) {
|
||||
valid, fragment := m.isValidPacket(d, packetData)
|
||||
if !valid {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -668,6 +669,13 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO: pass fragments of routed packets to forwarder
|
||||
if fragment {
|
||||
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
return false
|
||||
}
|
||||
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||
@@ -678,7 +686,7 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
||||
}
|
||||
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData, size)
|
||||
}
|
||||
|
||||
// handleLocalTraffic handles local traffic.
|
||||
@@ -739,7 +747,7 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
|
||||
// handleRoutedTraffic handles routed traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled.Load() {
|
||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
@@ -749,6 +757,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
|
||||
// Pass to native stack if native router is enabled or forced
|
||||
if m.nativeRouter.Load() {
|
||||
m.trackInbound(d, srcIP, dstIP, nil, size)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -770,6 +779,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
RxPackets: 1,
|
||||
RxBytes: uint64(size),
|
||||
})
|
||||
return true
|
||||
}
|
||||
@@ -812,17 +823,32 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
// isValidPacket checks if the packet is valid.
|
||||
// It returns true, false if the packet is valid and not a fragment.
|
||||
// It returns true, true if the packet is a fragment and valid.
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||
return false
|
||||
return false, false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
m.logger.Trace("packet doesn't have network and transport layers")
|
||||
return false
|
||||
l := len(d.decoded)
|
||||
|
||||
// L3 and L4 are mandatory
|
||||
if l >= 2 {
|
||||
return true, false
|
||||
}
|
||||
return true
|
||||
|
||||
// Fragments are also valid
|
||||
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
ip4 := d.ip4
|
||||
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
||||
return true, true
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Trace("packet doesn't have network and transport layers")
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -24,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
|
||||
@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
|
||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
params.Logger = getLogger()
|
||||
}
|
||||
|
||||
mux := &UDPMuxDefault{
|
||||
@@ -455,3 +455,9 @@ func newBufferHolder(size int) *bufferHolder {
|
||||
buf: make([]byte, size),
|
||||
}
|
||||
}
|
||||
|
||||
func getLogger() logging.LeveledLogger {
|
||||
fac := logging.NewDefaultLoggerFactory()
|
||||
//fac.Writer = log.StandardLogger().Writer()
|
||||
return fac.NewLogger("ice")
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ type UniversalUDPMuxParams struct {
|
||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
params.Logger = getLogger()
|
||||
}
|
||||
if params.XORMappedAddrCacheTTL == 0 {
|
||||
params.XORMappedAddrCacheTTL = time.Second * 25
|
||||
|
||||
@@ -357,7 +357,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
|
||||
func getFwmark() int {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.NetbirdFwmark
|
||||
return nbnet.ControlPlaneMark
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@@ -15,7 +14,7 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
|
||||
@@ -94,12 +94,17 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
p.codeVerifier = codeVerifier
|
||||
|
||||
codeChallenge := createCodeChallenge(codeVerifier)
|
||||
authURL := p.oAuthConfig.AuthCodeURL(
|
||||
state,
|
||||
|
||||
params := []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
)
|
||||
}
|
||||
if !p.providerConfig.DisablePromptLogin {
|
||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
return AuthFlowInfo{
|
||||
VerificationURIComplete: authURL,
|
||||
|
||||
49
client/internal/auth/pkce_flow_test.go
Normal file
49
client/internal/auth/pkce_flow_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestPromptLogin(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
prompt bool
|
||||
}{
|
||||
{"PromptLogin", true},
|
||||
{"NoPromptLogin", false},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
Scope: "openid email profile",
|
||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||
UseIDToken: true,
|
||||
DisablePromptLogin: !tc.prompt,
|
||||
}
|
||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
|
||||
}
|
||||
authInfo, err := pkce.RequestAuthInfo(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to request auth info: %v", err)
|
||||
}
|
||||
pattern := "prompt=login"
|
||||
if tc.prompt {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
||||
} else {
|
||||
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -349,6 +349,25 @@ func (c *ConnectClient) Engine() *Engine {
|
||||
return e
|
||||
}
|
||||
|
||||
// GetLatestNetworkMap returns the latest network map from the engine.
|
||||
func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
engine := c.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine is not initialized")
|
||||
}
|
||||
|
||||
networkMap, err := engine.GetLatestNetworkMap()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get latest network map: %w", err)
|
||||
}
|
||||
|
||||
if networkMap == nil {
|
||||
return nil, errors.New("network map is not available")
|
||||
}
|
||||
|
||||
return networkMap, nil
|
||||
}
|
||||
|
||||
// Status returns the current client status
|
||||
func (c *ConnectClient) Status() StatusType {
|
||||
if c == nil {
|
||||
|
||||
1022
client/internal/debug/debug.go
Normal file
1022
client/internal/debug/debug.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,8 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
package debug
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
@@ -14,36 +13,31 @@ import (
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// addFirewallRules collects and adds firewall rules to the archive
|
||||
func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
log.Info("Collecting firewall rules")
|
||||
// Collect and add iptables rules
|
||||
iptablesRules, err := collectIPTablesRules()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||
} else {
|
||||
if req.GetAnonymize() {
|
||||
iptablesRules = anonymizer.AnonymizeString(iptablesRules)
|
||||
if g.anonymize {
|
||||
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||
}
|
||||
if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Collect and add nftables rules
|
||||
nftablesRules, err := collectNFTablesRules()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect nftables rules: %v", err)
|
||||
} else {
|
||||
if req.GetAnonymize() {
|
||||
nftablesRules = anonymizer.AnonymizeString(nftablesRules)
|
||||
if g.anonymize {
|
||||
nftablesRules = g.anonymizer.AnonymizeString(nftablesRules)
|
||||
}
|
||||
if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
||||
if err := g.addFileToZip(strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
||||
log.Warnf("Failed to add nftables rules to bundle: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -65,16 +59,13 @@ func collectIPTablesRules() (string, error) {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
// Then get verbose statistics for each table
|
||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||
|
||||
// Get list of tables
|
||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||
|
||||
for _, table := range tables {
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
|
||||
// Get verbose statistics for the entire table
|
||||
stats, err := getTableStatistics(table)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||
@@ -182,12 +173,10 @@ func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
|
||||
continue
|
||||
}
|
||||
|
||||
// Format chains
|
||||
for _, chain := range chains {
|
||||
formatChain(conn, table, chain, &builder)
|
||||
}
|
||||
|
||||
// Format sets
|
||||
if sets, err := conn.GetSets(table); err != nil {
|
||||
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
|
||||
} else if len(sets) > 0 {
|
||||
7
client/internal/debug/debug_mobile.go
Normal file
7
client/internal/debug/debug_mobile.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build ios || android
|
||||
|
||||
package debug
|
||||
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
return nil
|
||||
}
|
||||
8
client/internal/debug/debug_nonlinux.go
Normal file
8
client/internal/debug/debug_nonlinux.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package debug
|
||||
|
||||
// collectFirewallRules returns nothing on non-linux systems
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
return nil
|
||||
}
|
||||
25
client/internal/debug/debug_nonmobile.go
Normal file
25
client/internal/debug/debug_nonmobile.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
routes, err := systemops.GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
// TODO: get routes including nexthop
|
||||
routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
|
||||
routesReader := strings.NewReader(routesContent)
|
||||
if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
|
||||
return fmt.Errorf("add routes file to zip: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package debug
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
|
||||
continue
|
||||
}
|
||||
|
||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
||||
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
return listOfDomains
|
||||
}
|
||||
|
||||
@@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
}
|
||||
|
||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
c.removeEntry(origPattern, priority)
|
||||
|
||||
// Check if handler implements SubdomainMatcher interface
|
||||
matchSubdomains := false
|
||||
@@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
|
||||
c.removeEntry(pattern, priority)
|
||||
}
|
||||
|
||||
func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
return
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
||||
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
for _, entry := range c.handlers {
|
||||
if strings.EqualFold(entry.Pattern, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
|
||||
@@ -443,14 +443,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
for _, handler := range handlers {
|
||||
handler.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Verify handler exists check
|
||||
for priority, shouldExist := range tt.expectedCalls {
|
||||
if shouldExist {
|
||||
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
||||
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -470,45 +462,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(testQuery, dns.TypeA)
|
||||
|
||||
// Keep track of mocks for the final assertion in Step 4
|
||||
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
|
||||
|
||||
// Add handlers in mixed order
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||
|
||||
// Test 1: Initial state with all three handlers
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Test 1: Initial state
|
||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Highest priority handler (routeHandler) should be called
|
||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
chain.ServeDNS(w1, r)
|
||||
routeHandler.AssertExpectations(t)
|
||||
|
||||
routeHandler.ExpectedCalls = nil
|
||||
routeHandler.Calls = nil
|
||||
matchHandler.ExpectedCalls = nil
|
||||
matchHandler.Calls = nil
|
||||
defaultHandler.ExpectedCalls = nil
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 2: Remove highest priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||
assert.True(t, chain.HasHandlers(testDomain))
|
||||
|
||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Now middle priority handler (matchHandler) should be called
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
chain.ServeDNS(w2, r)
|
||||
matchHandler.AssertExpectations(t)
|
||||
|
||||
matchHandler.ExpectedCalls = nil
|
||||
matchHandler.Calls = nil
|
||||
defaultHandler.ExpectedCalls = nil
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 3: Remove middle priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||
assert.True(t, chain.HasHandlers(testDomain))
|
||||
|
||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Now lowest priority handler (defaultHandler) should be called
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
chain.ServeDNS(w3, r)
|
||||
defaultHandler.AssertExpectations(t)
|
||||
|
||||
defaultHandler.ExpectedCalls = nil
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 4: Remove last handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||
|
||||
assert.False(t, chain.HasHandlers(testDomain))
|
||||
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||
|
||||
for _, m := range mocks {
|
||||
m.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
@@ -830,3 +846,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addPattern string
|
||||
removePattern string
|
||||
queryPattern string
|
||||
shouldBeRemoved bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "exact same pattern",
|
||||
addPattern: "example.com.",
|
||||
removePattern: "example.com.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding and removing with identical patterns",
|
||||
},
|
||||
{
|
||||
name: "case difference",
|
||||
addPattern: "Example.Com.",
|
||||
removePattern: "EXAMPLE.COM.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding with mixed case, removing with uppercase",
|
||||
},
|
||||
{
|
||||
name: "reversed case difference",
|
||||
addPattern: "EXAMPLE.ORG.",
|
||||
removePattern: "example.org.",
|
||||
queryPattern: "example.org.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding with uppercase, removing with lowercase",
|
||||
},
|
||||
{
|
||||
name: "add wildcard, remove wildcard",
|
||||
addPattern: "*.example.com.",
|
||||
removePattern: "*.example.com.",
|
||||
queryPattern: "sub.example.com.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding and removing with identical wildcard patterns",
|
||||
},
|
||||
{
|
||||
name: "add wildcard, remove transformed pattern",
|
||||
addPattern: "*.example.net.",
|
||||
removePattern: "example.net.",
|
||||
queryPattern: "sub.example.net.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding with wildcard, removing with non-wildcard pattern",
|
||||
},
|
||||
{
|
||||
name: "add transformed pattern, remove wildcard",
|
||||
addPattern: "example.io.",
|
||||
removePattern: "*.example.io.",
|
||||
queryPattern: "example.io.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
|
||||
},
|
||||
{
|
||||
name: "trailing dot difference",
|
||||
addPattern: "example.dev",
|
||||
removePattern: "example.dev.",
|
||||
queryPattern: "example.dev.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding without trailing dot, removing with trailing dot",
|
||||
},
|
||||
{
|
||||
name: "reversed trailing dot difference",
|
||||
addPattern: "example.app.",
|
||||
removePattern: "example.app",
|
||||
queryPattern: "example.app.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding with trailing dot, removing without trailing dot",
|
||||
},
|
||||
{
|
||||
name: "mixed case and wildcard",
|
||||
addPattern: "*.Example.Site.",
|
||||
removePattern: "*.EXAMPLE.SITE.",
|
||||
queryPattern: "sub.example.site.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding mixed case wildcard, removing uppercase wildcard",
|
||||
},
|
||||
{
|
||||
name: "root zone",
|
||||
addPattern: ".",
|
||||
removePattern: ".",
|
||||
queryPattern: "random.domain.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding and removing root zone",
|
||||
},
|
||||
{
|
||||
name: "wrong domain",
|
||||
addPattern: "example.com.",
|
||||
removePattern: "different.com.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding one domain, trying to remove a different domain",
|
||||
},
|
||||
{
|
||||
name: "subdomain mismatch",
|
||||
addPattern: "sub.example.com.",
|
||||
removePattern: "example.com.",
|
||||
queryPattern: "sub.example.com.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding subdomain, trying to remove parent domain",
|
||||
},
|
||||
{
|
||||
name: "parent domain mismatch",
|
||||
addPattern: "example.com.",
|
||||
removePattern: "sub.example.com.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding parent domain, trying to remove subdomain",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
handler := &nbdns.MockHandler{}
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
// First verify no handler is called before adding any
|
||||
chain.ServeDNS(w, r)
|
||||
handler.AssertNotCalled(t, "ServeDNS")
|
||||
|
||||
// Add handler
|
||||
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
|
||||
|
||||
// Verify handler is called after adding
|
||||
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||
chain.ServeDNS(w, r)
|
||||
handler.AssertExpectations(t)
|
||||
|
||||
// Reset mock for the next test
|
||||
handler.ExpectedCalls = nil
|
||||
|
||||
// Remove handler
|
||||
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
|
||||
|
||||
// Set up expectations based on whether removal should succeed
|
||||
if !tt.shouldBeRemoved {
|
||||
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||
}
|
||||
|
||||
// Test if handler is still called after removal attempt
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
if tt.shouldBeRemoved {
|
||||
handler.AssertNotCalled(t, "ServeDNS",
|
||||
"Handler should not be called after successful removal with pattern %q",
|
||||
tt.removePattern)
|
||||
} else {
|
||||
handler.AssertExpectations(t)
|
||||
handler.ExpectedCalls = nil
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
@@ -12,8 +14,8 @@ import (
|
||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa"
|
||||
ipv6ReverseZone = ".ip6.arpa"
|
||||
ipv4ReverseZone = ".in-addr.arpa."
|
||||
ipv6ReverseZone = ".ip6.arpa."
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
@@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
|
||||
for _, domain := range nsConfig.Domains {
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(domain, "."),
|
||||
Domain: strings.ToLower(dns.Fqdn(domain)),
|
||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||
})
|
||||
}
|
||||
@@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||
MatchOnly: matchOnly,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
continue
|
||||
}
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, dConf.Domain)
|
||||
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
|
||||
@@ -17,9 +17,12 @@ import (
|
||||
|
||||
var (
|
||||
userenv = syscall.NewLazyDLL("userenv.dll")
|
||||
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
||||
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
||||
|
||||
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
continue
|
||||
}
|
||||
if !dConf.MatchOnly {
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
matchDomains = append(matchDomains, "."+dConf.Domain)
|
||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
@@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
}
|
||||
|
||||
if err := r.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -184,6 +191,26 @@ func (r *registryConfigurator) string() string {
|
||||
return "registry"
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) flushDNSCache() error {
|
||||
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Errorf("Recovered from panic in flushDNSCache: %v", rec)
|
||||
}
|
||||
}()
|
||||
|
||||
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||
if ret == 0 {
|
||||
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("DnsFlushResolverCache failed")
|
||||
}
|
||||
|
||||
log.Info("flushed DNS cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
@@ -236,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := r.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -71,6 +71,12 @@ func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
||||
|
||||
value, found := d.records.Load(key)
|
||||
if !found {
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
r.Question[0].Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(r)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// MockServer is the mock instance of a dns server
|
||||
@@ -13,17 +14,17 @@ type MockServer struct {
|
||||
InitializeFunc func() error
|
||||
StopFunc func()
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
RegisterHandlerFunc func([]string, dns.Handler, int)
|
||||
DeregisterHandlerFunc func([]string, int)
|
||||
RegisterHandlerFunc func(domain.List, dns.Handler, int)
|
||||
DeregisterHandlerFunc func(domain.List, int)
|
||||
}
|
||||
|
||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
if m.RegisterHandlerFunc != nil {
|
||||
m.RegisterHandlerFunc(domains, handler, priority)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
|
||||
func (m *MockServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
if m.DeregisterHandlerFunc != nil {
|
||||
m.DeregisterHandlerFunc(domains, priority)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
||||
continue
|
||||
}
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
|
||||
matchDomains = append(matchDomains, "~."+dConf.Domain)
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
|
||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
@@ -32,8 +35,8 @@ type IosDnsManager interface {
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains []string, priority int)
|
||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains domain.List, priority int)
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() string
|
||||
@@ -65,6 +68,7 @@ type DefaultServer struct {
|
||||
previousConfigHash uint64
|
||||
currentConfig HostDNSConfig
|
||||
handlerChain *HandlerChain
|
||||
extraDomains map[domain.Domain]int
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
@@ -164,13 +168,15 @@ func newDefaultServer(
|
||||
stateManager *statemanager.Manager,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
handlerChain := NewHandlerChain()
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
@@ -181,14 +187,26 @@ func newDefaultServer(
|
||||
hostsDNSHolder: newHostsDNSHolder(),
|
||||
}
|
||||
|
||||
// register with root zone, handler chain takes care of the routing
|
||||
dnsService.RegisterMux(".", handlerChain)
|
||||
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
// RegisterHandler registers a handler for the given domains with the given priority.
|
||||
// Any previously registered handler for the same domain and priority will be replaced.
|
||||
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.registerHandler(domains, handler, priority)
|
||||
s.registerHandler(domains.ToPunycodeList(), handler, priority)
|
||||
|
||||
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
|
||||
for _, domain := range domains {
|
||||
// convert to zone with simple ref counter
|
||||
s.extraDomains[toZone(domain)]++
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
@@ -200,15 +218,23 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
||||
continue
|
||||
}
|
||||
s.handlerChain.AddHandler(domain, handler, priority)
|
||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
||||
// DeregisterHandler deregisters the handler for the given domains with the given priority.
|
||||
func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.deregisterHandler(domains, priority)
|
||||
s.deregisterHandler(domains.ToPunycodeList(), priority)
|
||||
for _, domain := range domains {
|
||||
zone := toZone(domain)
|
||||
s.extraDomains[zone]--
|
||||
if s.extraDomains[zone] <= 0 {
|
||||
delete(s.extraDomains, zone)
|
||||
}
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
@@ -221,11 +247,6 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
}
|
||||
|
||||
s.handlerChain.RemoveHandler(domain, priority)
|
||||
|
||||
// Only deregister from service if no handlers remain
|
||||
if !s.handlerChain.HasHandlers(domain) {
|
||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,6 +307,8 @@ func (s *DefaultServer) Stop() {
|
||||
}
|
||||
|
||||
s.service.Stop()
|
||||
|
||||
maps.Clear(s.extraDomains)
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
@@ -390,7 +413,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be Disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if update.ServiceEnable {
|
||||
_ = s.service.Listen()
|
||||
if err := s.service.Listen(); err != nil {
|
||||
log.Errorf("failed to start DNS service: %v", err)
|
||||
}
|
||||
} else if !s.permanent {
|
||||
s.service.Stop()
|
||||
}
|
||||
@@ -413,17 +438,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
hostUpdate := s.currentConfig
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||
hostUpdate.RouteAll = false
|
||||
s.currentConfig.RouteAll = false
|
||||
}
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
||||
log.Error(err)
|
||||
s.handleErrNoGroupaAll(err)
|
||||
}
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
@@ -441,6 +462,43 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyHostConfig() {
|
||||
if s.hostManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// prevent reapplying config if we're shutting down
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
config := s.currentConfig
|
||||
|
||||
existingDomains := make(map[string]struct{})
|
||||
for _, d := range config.Domains {
|
||||
existingDomains[d.Domain] = struct{}{}
|
||||
}
|
||||
|
||||
// add extra domains only if they're not already in the config
|
||||
for domain := range s.extraDomains {
|
||||
domainStr := domain.PunycodeString()
|
||||
|
||||
if _, exists := existingDomains[domainStr]; !exists {
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: domainStr,
|
||||
MatchOnly: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("extra match domains: %v", s.extraDomains)
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||
s.handleErrNoGroupaAll(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
||||
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
||||
return
|
||||
@@ -690,10 +748,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
s.handleErrNoGroupaAll(err)
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
@@ -728,12 +783,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
if s.hostManager != nil {
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
s.handleErrNoGroupaAll(err)
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
s.applyHostConfig()
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
}
|
||||
@@ -836,3 +886,13 @@ func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func toZone(d domain.Domain) domain.Domain {
|
||||
return domain.Domain(
|
||||
nbdns.NormalizeZone(
|
||||
dns.Fqdn(
|
||||
strings.ToLower(d.PunycodeString()),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -29,16 +29,17 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type mocWGIface struct {
|
||||
filter device.PacketFilter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Name() string {
|
||||
panic("implement me")
|
||||
return "utun2301"
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() wgaddr.Address {
|
||||
@@ -1448,3 +1449,497 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialConfig nbdns.Config
|
||||
registerDomains []domain.List
|
||||
deregisterDomains []domain.List
|
||||
finalConfig nbdns.Config
|
||||
expectedDomains []string
|
||||
expectedMatchOnly []string
|
||||
applyHostConfigCall int
|
||||
}{
|
||||
{
|
||||
name: "Register domains before config update",
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
},
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register domains after config update",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register overlapping domains",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "overlap.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "overlap.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"overlap.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register and deregister domains",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
{"extra3.example.com", "extra4.example.com"},
|
||||
},
|
||||
deregisterDomains: []domain.List{
|
||||
{"extra1.example.com", "extra3.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra2.example.com.",
|
||||
"extra4.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra2.example.com.",
|
||||
"extra4.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 4,
|
||||
},
|
||||
{
|
||||
name: "Register domains with ref counter",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "duplicate.example.com"},
|
||||
{"other.example.com", "duplicate.example.com"},
|
||||
},
|
||||
deregisterDomains: []domain.List{
|
||||
{"duplicate.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra.example.com.",
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 4,
|
||||
},
|
||||
{
|
||||
name: "Config update with new domains after registration",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "duplicate.example.com"},
|
||||
},
|
||||
finalConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "newconfig.example.com"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"newconfig.example.com.",
|
||||
"extra.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 3,
|
||||
},
|
||||
{
|
||||
name: "Deregister domain that is part of customZones",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "protected.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "protected.example.com"},
|
||||
},
|
||||
deregisterDomains: []domain.List{
|
||||
{"protected.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"extra.example.com.",
|
||||
"config.example.com.",
|
||||
"protected.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 3,
|
||||
},
|
||||
{
|
||||
name: "Register domain that is part of nameserver group",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "overlap.ns.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"ns.example.com.",
|
||||
"overlap.ns.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"ns.example.com.",
|
||||
"overlap.ns.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedConfigs []HostDNSConfig
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
capturedConfigs = append(capturedConfigs, config)
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
wgInterface: &mocWGIface{},
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
// Apply initial configuration
|
||||
if tt.initialConfig.ServiceEnable {
|
||||
err := server.applyConfiguration(tt.initialConfig)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Register domains
|
||||
for _, domains := range tt.registerDomains {
|
||||
server.RegisterHandler(domains, &MockHandler{}, PriorityDefault)
|
||||
}
|
||||
|
||||
// Deregister domains if specified
|
||||
for _, domains := range tt.deregisterDomains {
|
||||
server.DeregisterHandler(domains, PriorityDefault)
|
||||
}
|
||||
|
||||
// Apply final configuration if specified
|
||||
if tt.finalConfig.ServiceEnable {
|
||||
err := server.applyConfiguration(tt.finalConfig)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify number of calls
|
||||
assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs),
|
||||
"Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs))
|
||||
|
||||
// Get the last applied config
|
||||
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
||||
|
||||
// Check all expected domains are present
|
||||
domainMap := make(map[string]bool)
|
||||
matchOnlyMap := make(map[string]bool)
|
||||
|
||||
for _, d := range lastConfig.Domains {
|
||||
domainMap[d.Domain] = true
|
||||
if d.MatchOnly {
|
||||
matchOnlyMap[d.Domain] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Verify expected domains
|
||||
for _, d := range tt.expectedDomains {
|
||||
assert.True(t, domainMap[d], "Expected domain %s not found in final config", d)
|
||||
}
|
||||
|
||||
// Verify match-only domains
|
||||
for _, d := range tt.expectedMatchOnly {
|
||||
assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d)
|
||||
}
|
||||
|
||||
// Verify no unexpected domains
|
||||
assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config")
|
||||
assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraDomainsRefCounting(t *testing.T) {
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
// Register domains from different handlers with same domain
|
||||
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
||||
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
||||
|
||||
// Verify refcount is 2
|
||||
zoneKey := toZone("shared.example.com")
|
||||
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
||||
|
||||
// Deregister one handler
|
||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
||||
|
||||
// Verify refcount is 1
|
||||
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
||||
|
||||
// Deregister the other handler
|
||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute)
|
||||
|
||||
// Verify domain is removed
|
||||
_, exists := server.extraDomains[zoneKey]
|
||||
assert.False(t, exists, "Domain should be removed after deregistering all handlers")
|
||||
}
|
||||
|
||||
func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||
var capturedConfig HostDNSConfig
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
capturedConfig = config
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault)
|
||||
|
||||
initialConfig := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
}
|
||||
err := server.applyConfiguration(initialConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var domains []string
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
assert.Contains(t, domains, "config.example.com.")
|
||||
assert.Contains(t, domains, "extra.example.com.")
|
||||
|
||||
// Now apply a new configuration with overlapping domain
|
||||
updatedConfig := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "extra.example.com"},
|
||||
},
|
||||
}
|
||||
err = server.applyConfiguration(updatedConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify both domains are in config, but no duplicates
|
||||
domains = []string{}
|
||||
matchOnlyCount := 0
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
if d.MatchOnly {
|
||||
matchOnlyCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Contains(t, domains, "config.example.com.")
|
||||
assert.Contains(t, domains, "extra.example.com.")
|
||||
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
||||
|
||||
// Extra domain should no longer be marked as match-only when in config
|
||||
matchOnlyDomain := ""
|
||||
for _, d := range capturedConfig.Domains {
|
||||
if d.Domain == "extra.example.com." && d.MatchOnly {
|
||||
matchOnlyDomain = d.Domain
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config")
|
||||
}
|
||||
|
||||
func TestDomainCaseHandling(t *testing.T) {
|
||||
var capturedConfig HostDNSConfig
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
capturedConfig = config
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
||||
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
||||
|
||||
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
||||
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
}
|
||||
err := server.applyConfiguration(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var domains []string
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
@@ -38,7 +37,6 @@ const (
|
||||
|
||||
type systemdDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
}
|
||||
|
||||
@@ -112,7 +110,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
continue
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: dns.Fqdn(dConf.Domain),
|
||||
Domain: dConf.Domain,
|
||||
MatchOnly: dConf.MatchOnly,
|
||||
})
|
||||
|
||||
@@ -124,18 +122,19 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
}
|
||||
|
||||
if config.RouteAll {
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
|
||||
return fmt.Errorf("set link as default dns router: %w", err)
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: nbdns.RootZone,
|
||||
MatchOnly: true,
|
||||
})
|
||||
s.routingAll = true
|
||||
} else if s.routingAll {
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil {
|
||||
return fmt.Errorf("remove link as default dns router: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
state := &ShutdownState{
|
||||
@@ -151,6 +150,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -163,7 +167,8 @@ func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdD
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
||||
}
|
||||
return s.flushCaches()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||
@@ -183,10 +188,14 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||
}
|
||||
|
||||
return s.flushCaches()
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) flushCaches() error {
|
||||
func (s *systemdDbusConfigurator) flushDNSCache() error {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
||||
|
||||
@@ -18,14 +18,16 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
UpstreamTimeout = 15 * time.Second
|
||||
|
||||
failsTillDeact = int32(5)
|
||||
reactivatePeriod = 30 * time.Second
|
||||
upstreamTimeout = 15 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
@@ -66,7 +68,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: domain,
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
statusRecorder: statusRecorder,
|
||||
@@ -106,9 +108,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}()
|
||||
|
||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||
if r.Extra == nil {
|
||||
r.SetEdns0(4096, false)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
@@ -336,3 +336,51 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati
|
||||
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// MTU - ip + udp headers
|
||||
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower.
|
||||
client.UDPSize = iface.DefaultMTU - (60 + 8)
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
t time.Duration
|
||||
err error
|
||||
)
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with udp: %w", err)
|
||||
}
|
||||
|
||||
if rm == nil || !rm.MsgHdr.Truncated {
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
client.Net = "tcp"
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
|
||||
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
timeout := upstreamTimeout
|
||||
timeout := UpstreamTimeout
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
timeout = time.Until(deadline)
|
||||
}
|
||||
|
||||
@@ -34,6 +34,5 @@ func newUpstreamResolver(
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
upstreamExchangeClient := &dns.Client{}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||
}
|
||||
|
||||
timeout := upstreamTimeout
|
||||
timeout := UpstreamTimeout
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
timeout = time.Until(deadline)
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
}
|
||||
|
||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||
return client.Exchange(r, upstream)
|
||||
return ExchangeWithFallback(nil, client, r, upstream)
|
||||
}
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
name: "Should Resolve A Record",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
timeout: upstreamTimeout,
|
||||
timeout: UpstreamTimeout,
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
@@ -48,7 +48,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||
cancelCTX: true,
|
||||
timeout: upstreamTimeout,
|
||||
timeout: UpstreamTimeout,
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
r: new(dns.Msg),
|
||||
rtt: time.Millisecond,
|
||||
},
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
}
|
||||
|
||||
@@ -3,34 +3,45 @@ package dnsfwd
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
const upstreamTimeout = 15 * time.Second
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
domains []string
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
domains []string
|
||||
statusRecorder *peer.Status
|
||||
|
||||
dnsServer *dns.Server
|
||||
mux *dns.ServeMux
|
||||
|
||||
resId sync.Map
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Listen(domains []string) error {
|
||||
func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error {
|
||||
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
@@ -42,22 +53,30 @@ func (f *DNSForwarder) Listen(domains []string) error {
|
||||
f.dnsServer = dnsServer
|
||||
f.mux = mux
|
||||
|
||||
f.UpdateDomains(domains)
|
||||
f.UpdateDomains(domains, resIds)
|
||||
|
||||
return dnsServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
||||
func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) {
|
||||
log.Debugf("Updating domains from %v to %v", f.domains, domains)
|
||||
|
||||
for _, d := range f.domains {
|
||||
f.mux.HandleRemove(d)
|
||||
}
|
||||
f.resId.Clear()
|
||||
|
||||
newDomains := filterDomains(domains)
|
||||
for _, d := range newDomains {
|
||||
f.mux.HandleFunc(d, f.handleDNSQuery)
|
||||
}
|
||||
|
||||
for domain, resId := range resIds {
|
||||
if domain != "" {
|
||||
f.resId.Store(domain, resId)
|
||||
}
|
||||
}
|
||||
|
||||
f.domains = newDomains
|
||||
}
|
||||
|
||||
@@ -79,41 +98,87 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
domain := question.Name
|
||||
|
||||
resp := query.SetReply(query)
|
||||
var network string
|
||||
switch question.Qtype {
|
||||
case dns.TypeA:
|
||||
network = "ip4"
|
||||
case dns.TypeAAAA:
|
||||
network = "ip6"
|
||||
default:
|
||||
// TODO: Handle other types
|
||||
|
||||
ips, err := net.LookupIP(domain)
|
||||
if err != nil {
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
// Pass through NXDOMAIN
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
resp.Rcode = dns.RcodeNotImplemented
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||
defer cancel()
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
||||
if err != nil {
|
||||
f.handleDNSError(w, resp, domain, err)
|
||||
return
|
||||
}
|
||||
|
||||
resId := f.getResIdForDomain(strings.TrimSuffix(domain, "."))
|
||||
if resId != "" {
|
||||
for _, ip := range ips {
|
||||
var ipWithSuffix string
|
||||
if ip.Is4() {
|
||||
ipWithSuffix = ip.String() + "/32"
|
||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix)
|
||||
} else {
|
||||
ipWithSuffix = ip.String() + "/128"
|
||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix)
|
||||
}
|
||||
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId)
|
||||
}
|
||||
}
|
||||
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
// Pass through NXDOMAIN
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
|
||||
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
|
||||
for _, ip := range ips {
|
||||
var respRecord dns.RR
|
||||
if ip.To4() == nil {
|
||||
if ip.Is6() {
|
||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||
rr := dns.AAAA{
|
||||
AAAA: ip,
|
||||
AAAA: ip.AsSlice(),
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
@@ -125,7 +190,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
} else {
|
||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||
rr := dns.A{
|
||||
A: ip,
|
||||
A: ip.AsSlice(),
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeA,
|
||||
@@ -137,10 +202,36 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
}
|
||||
resp.Answer = append(resp.Answer, respRecord)
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
func (f *DNSForwarder) getResIdForDomain(domain string) string {
|
||||
var selectedResId string
|
||||
var bestScore int
|
||||
|
||||
f.resId.Range(func(key, value interface{}) bool {
|
||||
var score int
|
||||
pattern := key.(string)
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(pattern, "*."):
|
||||
baseDomain := strings.TrimPrefix(pattern, "*.")
|
||||
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
|
||||
score = len(baseDomain)
|
||||
}
|
||||
case domain == pattern:
|
||||
score = math.MaxInt
|
||||
default:
|
||||
return true
|
||||
}
|
||||
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
selectedResId = value.(string)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return selectedResId
|
||||
}
|
||||
|
||||
// filterDomains returns a list of normalized domains
|
||||
|
||||
95
client/internal/dnsfwd/forwarder_test.go
Normal file
95
client/internal/dnsfwd/forwarder_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetResIdForDomain(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
storedMappings map[string]string // key: domain pattern, value: resId
|
||||
queryDomain string
|
||||
expectedResId string
|
||||
}{
|
||||
{
|
||||
name: "Empty map returns empty string",
|
||||
storedMappings: map[string]string{},
|
||||
queryDomain: "example.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
name: "Exact match returns stored resId",
|
||||
storedMappings: map[string]string{"example.com": "res1"},
|
||||
queryDomain: "example.com",
|
||||
expectedResId: "res1",
|
||||
},
|
||||
{
|
||||
name: "Wildcard pattern matches base domain",
|
||||
storedMappings: map[string]string{"*.example.com": "res2"},
|
||||
queryDomain: "example.com",
|
||||
expectedResId: "res2",
|
||||
},
|
||||
{
|
||||
name: "Wildcard pattern matches subdomain",
|
||||
storedMappings: map[string]string{"*.example.com": "res3"},
|
||||
queryDomain: "foo.example.com",
|
||||
expectedResId: "res3",
|
||||
},
|
||||
{
|
||||
name: "Wildcard pattern does not match different domain",
|
||||
storedMappings: map[string]string{"*.example.com": "res4"},
|
||||
queryDomain: "foo.notexample.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
name: "Non-wildcard pattern does not match subdomain",
|
||||
storedMappings: map[string]string{"example.com": "res5"},
|
||||
queryDomain: "foo.example.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
name: "Exact match over overlapping wildcard",
|
||||
storedMappings: map[string]string{
|
||||
"*.example.com": "resWildcard",
|
||||
"foo.example.com": "resExact",
|
||||
},
|
||||
queryDomain: "foo.example.com",
|
||||
expectedResId: "resExact",
|
||||
},
|
||||
{
|
||||
name: "Overlapping wildcards: Select more specific wildcard",
|
||||
storedMappings: map[string]string{
|
||||
"*.example.com": "resA",
|
||||
"*.sub.example.com": "resB",
|
||||
},
|
||||
queryDomain: "bar.sub.example.com",
|
||||
expectedResId: "resB",
|
||||
},
|
||||
{
|
||||
name: "Wildcard multi-level subdomain match",
|
||||
storedMappings: map[string]string{
|
||||
"*.example.com": "resMulti",
|
||||
},
|
||||
queryDomain: "a.b.example.com",
|
||||
expectedResId: "resMulti",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fwd := &DNSForwarder{
|
||||
resId: sync.Map{},
|
||||
}
|
||||
|
||||
for domainPattern, resId := range tc.storedMappings {
|
||||
fwd.resId.Store(domainPattern, resId)
|
||||
}
|
||||
|
||||
got := fwd.getResIdForDomain(tc.queryDomain)
|
||||
if got != tc.expectedResId {
|
||||
t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,19 +20,21 @@ const (
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
firewall firewall.Manager
|
||||
statusRecorder *peer.Status
|
||||
|
||||
fwRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager) *Manager {
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(domains []string) error {
|
||||
func (m *Manager) Start(domains []string, resIds map[string]string) error {
|
||||
log.Infof("starting DNS forwarder")
|
||||
if m.dnsForwarder != nil {
|
||||
return nil
|
||||
@@ -41,9 +44,9 @@ func (m *Manager) Start(domains []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.statusRecorder)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(domains); err != nil {
|
||||
if err := m.dnsForwarder.Listen(domains, resIds); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||
}
|
||||
@@ -52,12 +55,12 @@ func (m *Manager) Start(domains []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateDomains(domains []string) {
|
||||
func (m *Manager) UpdateDomains(domains []string, resIds map[string]string) {
|
||||
if m.dnsForwarder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
m.dnsForwarder.UpdateDomains(domains)
|
||||
m.dnsForwarder.UpdateDomains(domains, resIds)
|
||||
}
|
||||
|
||||
func (m *Manager) Stop(ctx context.Context) error {
|
||||
|
||||
@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
|
||||
|
||||
// start flow manager right after interface creation
|
||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
@@ -952,11 +952,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply ACLs in the beginning to avoid security leaks
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||
@@ -967,14 +962,19 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
// DNS forwarder
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains)
|
||||
dnsRouteDomains, resourceIds := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains, resourceIds)
|
||||
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
// acls might need routing to be enabled, so we apply after routes
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
// Ingress forward rules
|
||||
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
||||
log.Errorf("failed to update forward rules, err: %v", err)
|
||||
@@ -1079,21 +1079,29 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
return routes
|
||||
}
|
||||
|
||||
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
|
||||
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) ([]string, map[string]string) {
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
var dnsRoutes []string
|
||||
resIds := make(map[string]string)
|
||||
for _, protoRoute := range protoRoutes {
|
||||
if len(protoRoute.Domains) == 0 {
|
||||
continue
|
||||
}
|
||||
if protoRoute.Peer == myPubKey {
|
||||
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
|
||||
// resource ID is the first part of the ID
|
||||
resId := strings.Split(protoRoute.ID, ":")
|
||||
for _, domain := range protoRoute.Domains {
|
||||
if len(resId) > 0 {
|
||||
resIds[domain] = resId[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return dnsRoutes
|
||||
return dnsRoutes, resIds
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
||||
@@ -1223,36 +1231,19 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
||||
PreSharedKey: e.config.PreSharedKey,
|
||||
}
|
||||
|
||||
if e.config.RosenpassEnabled && !e.config.RosenpassPermissive {
|
||||
lk := []byte(e.config.WgPrivateKey.PublicKey().String())
|
||||
rk := []byte(wgConfig.RemoteKey)
|
||||
var keyInput []byte
|
||||
if string(lk) > string(rk) {
|
||||
//nolint:gocritic
|
||||
keyInput = append(lk[:16], rk[:16]...)
|
||||
} else {
|
||||
//nolint:gocritic
|
||||
keyInput = append(rk[:16], lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgConfig.PreSharedKey = &key
|
||||
}
|
||||
|
||||
// randomize connection timeout
|
||||
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
||||
config := peer.ConnConfig{
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
Timeout: timeout,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
||||
RosenpassAddr: e.getRosenpassAddr(),
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
Timeout: timeout,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
RosenpassConfig: peer.RosenpassConfig{
|
||||
PubKey: e.getRosenpassPubKey(),
|
||||
Addr: e.getRosenpassAddr(),
|
||||
PermissiveMode: e.config.RosenpassPermissive,
|
||||
},
|
||||
ICEConfig: icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
@@ -1760,7 +1751,7 @@ func (e *Engine) GetWgAddr() net.IP {
|
||||
}
|
||||
|
||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||
func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[string]string) {
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
@@ -1774,15 +1765,15 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||
if len(domains) > 0 {
|
||||
log.Infof("enable domain router service for domains: %v", domains)
|
||||
if e.dnsForwardMgr == nil {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(domains); err != nil {
|
||||
if err := e.dnsForwardMgr.Start(domains, resIds); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
} else {
|
||||
log.Infof("update domain router service for domains: %v", domains)
|
||||
e.dnsForwardMgr.UpdateDomains(domains)
|
||||
e.dnsForwardMgr.UpdateDomains(domains, resIds)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
log.Infof("disable domain router service")
|
||||
|
||||
@@ -49,6 +49,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -1400,15 +1401,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
Relay: &server.Relay{
|
||||
config := &types.Config{
|
||||
Stuns: []*types.Host{},
|
||||
TURNConfig: &types.TURNConfig{},
|
||||
Relay: &types.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &server.Host{
|
||||
Signal: &types.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
@@ -1446,7 +1447,9 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/ti-mo/netfilter"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const defaultChannelSize = 100
|
||||
@@ -176,7 +177,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||
srcIP := flow.TupleOrig.IP.SourceAddress
|
||||
dstIP := flow.TupleOrig.IP.DestinationAddress
|
||||
|
||||
if !c.relevantFlow(srcIP, dstIP) {
|
||||
if !c.relevantFlow(flow.Mark, srcIP, dstIP) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -193,7 +194,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||
}
|
||||
|
||||
flowID := c.getFlowID(flow.ID)
|
||||
direction := c.inferDirection(srcIP, dstIP)
|
||||
direction := c.inferDirection(flow.Mark, srcIP, dstIP)
|
||||
|
||||
eventType := nftypes.TypeStart
|
||||
eventStr := "New"
|
||||
@@ -224,15 +225,14 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||
}
|
||||
|
||||
// relevantFlow checks if the flow is related to the specified interface
|
||||
func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool {
|
||||
// TODO: filter traffic by interface
|
||||
|
||||
wgnet := c.iface.Address().Network
|
||||
if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) {
|
||||
return false
|
||||
func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
|
||||
if nbnet.IsDataPlaneMark(mark) {
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
// fallback if mark rules are not in place
|
||||
wgnet := c.iface.Address().Network
|
||||
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice())
|
||||
}
|
||||
|
||||
// mapRxPackets maps packet counts to RX based on flow direction
|
||||
@@ -282,7 +282,15 @@ func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID {
|
||||
return uuid.NewSHA1(c.instanceID, buf[:])
|
||||
}
|
||||
|
||||
func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
|
||||
func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes.Direction {
|
||||
switch mark {
|
||||
case nbnet.DataPlaneMarkIn:
|
||||
return nftypes.Ingress
|
||||
case nbnet.DataPlaneMarkOut:
|
||||
return nftypes.Egress
|
||||
}
|
||||
|
||||
// fallback if marks are not set
|
||||
wgaddr := c.iface.Address().IP
|
||||
wgnetwork := c.iface.Address().Network
|
||||
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
|
||||
@@ -298,8 +306,6 @@ func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
|
||||
case wgnetwork.Contains(dst):
|
||||
// resource network -> netbird network
|
||||
return nftypes.Egress
|
||||
|
||||
// TODO: handle site2site traffic
|
||||
}
|
||||
|
||||
return nftypes.DirectionUnknown
|
||||
|
||||
@@ -19,11 +19,9 @@ import (
|
||||
type rcvChan chan *types.EventFields
|
||||
type Logger struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
enabled atomic.Bool
|
||||
rcvChan atomic.Pointer[rcvChan]
|
||||
cancelReceiver context.CancelFunc
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgIfaceIPNet net.IPNet
|
||||
dnsCollection atomic.Bool
|
||||
@@ -31,12 +29,9 @@ type Logger struct {
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Logger{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgIfaceIPNet: wgIfaceIPNet,
|
||||
Store: store.NewMemoryStore(),
|
||||
@@ -70,8 +65,8 @@ func (l *Logger) startReceiver() {
|
||||
}
|
||||
|
||||
l.mux.Lock()
|
||||
ctx, cancel := context.WithCancel(l.ctx)
|
||||
l.cancelReceiver = cancel
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l.cancel = cancel
|
||||
l.mux.Unlock()
|
||||
|
||||
c := make(rcvChan, 100)
|
||||
@@ -91,25 +86,25 @@ func (l *Logger) startReceiver() {
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
var isExitNode bool
|
||||
if event.Direction == types.Ingress {
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
|
||||
event.SourceResourceID, isExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
||||
}
|
||||
} else if event.Direction == types.Egress {
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
|
||||
event.DestResourceID, isExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
||||
}
|
||||
var isSrcExitNode bool
|
||||
var isDestExitNode bool
|
||||
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
|
||||
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
||||
}
|
||||
|
||||
if l.shouldStore(eventFields, isExitNode) {
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
|
||||
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
||||
}
|
||||
|
||||
if l.shouldStore(eventFields, isSrcExitNode || isDestExitNode) {
|
||||
l.Store.StoreEvent(&event)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Disable() {
|
||||
func (l *Logger) Close() {
|
||||
l.stop()
|
||||
l.Store.Close()
|
||||
}
|
||||
@@ -121,9 +116,9 @@ func (l *Logger) stop() {
|
||||
|
||||
l.enabled.Store(false)
|
||||
l.mux.Lock()
|
||||
if l.cancelReceiver != nil {
|
||||
l.cancelReceiver()
|
||||
l.cancelReceiver = nil
|
||||
if l.cancel != nil {
|
||||
l.cancel()
|
||||
l.cancel = nil
|
||||
}
|
||||
l.rcvChan.Store(nil)
|
||||
l.mux.Unlock()
|
||||
@@ -142,11 +137,6 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
||||
l.exitNodeCollection.Store(exitNodeCollection)
|
||||
}
|
||||
|
||||
func (l *Logger) Close() {
|
||||
l.stop()
|
||||
l.cancel()
|
||||
}
|
||||
|
||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||
// check dns collection
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(context.Background(), nil, net.IPNet{})
|
||||
logger := logger.New(nil, net.IPNet{})
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
@@ -40,7 +39,7 @@ func TestStore(t *testing.T) {
|
||||
}
|
||||
|
||||
// test disable
|
||||
logger.Disable()
|
||||
logger.Close()
|
||||
wait()
|
||||
logger.StoreEvent(event)
|
||||
wait()
|
||||
|
||||
@@ -27,18 +27,18 @@ type Manager struct {
|
||||
logger nftypes.FlowLogger
|
||||
flowConfig *nftypes.FlowConfig
|
||||
conntrack nftypes.ConnTracker
|
||||
ctx context.Context
|
||||
receiverClient *client.GRPCClient
|
||||
publicKey []byte
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
var ipNet net.IPNet
|
||||
if iface != nil {
|
||||
ipNet = *iface.Address().Network
|
||||
}
|
||||
flowLogger := logger.New(ctx, statusRecorder, ipNet)
|
||||
flowLogger := logger.New(statusRecorder, ipNet)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
@@ -48,7 +48,6 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
|
||||
return &Manager{
|
||||
logger: flowLogger,
|
||||
conntrack: ct,
|
||||
ctx: ctx,
|
||||
publicKey: publicKey,
|
||||
}
|
||||
}
|
||||
@@ -68,21 +67,9 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
|
||||
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
// first make sender ready so events don't pile up
|
||||
if m.needsNewClient(previous) {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %v", err)
|
||||
}
|
||||
if err := m.resetClient(); err != nil {
|
||||
return fmt.Errorf("reset client: %w", err)
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
go m.receiveACKs(flowClient)
|
||||
go m.startSender()
|
||||
}
|
||||
|
||||
m.logger.Enable()
|
||||
@@ -96,17 +83,50 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) resetClient() error {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
go m.receiveACKs(ctx, flowClient)
|
||||
go m.startSender(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// disableFlow stops components for flow tracking
|
||||
func (m *Manager) disableFlow() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Stop()
|
||||
}
|
||||
|
||||
m.logger.Disable()
|
||||
m.logger.Close()
|
||||
|
||||
if m.receiverClient != nil {
|
||||
return m.receiverClient.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -133,17 +153,18 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||
|
||||
m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection)
|
||||
|
||||
changed := previous != nil && update.Enabled != previous.Enabled
|
||||
if update.Enabled {
|
||||
log.Infof("netflow manager enabled; starting netflow manager")
|
||||
if changed {
|
||||
log.Infof("netflow manager enabled; starting netflow manager")
|
||||
}
|
||||
return m.enableFlow(previous)
|
||||
}
|
||||
|
||||
log.Infof("netflow manager disabled; stopping netflow manager")
|
||||
err := m.disableFlow()
|
||||
if err != nil {
|
||||
log.Errorf("failed to disable netflow manager: %v", err)
|
||||
if changed {
|
||||
log.Infof("netflow manager disabled; stopping netflow manager")
|
||||
}
|
||||
return err
|
||||
return m.disableFlow()
|
||||
}
|
||||
|
||||
// Close cleans up all resources
|
||||
@@ -151,17 +172,9 @@ func (m *Manager) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Close()
|
||||
if err := m.disableFlow(); err != nil {
|
||||
log.Warnf("failed to disable flow manager: %v", err)
|
||||
}
|
||||
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("failed to close receiver client: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Close()
|
||||
}
|
||||
|
||||
// GetLogger returns the flow logger
|
||||
@@ -169,13 +182,13 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
|
||||
return m.logger
|
||||
}
|
||||
|
||||
func (m *Manager) startSender() {
|
||||
func (m *Manager) startSender(ctx context.Context) {
|
||||
ticker := time.NewTicker(m.flowConfig.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
events := m.logger.GetEvents()
|
||||
@@ -190,8 +203,8 @@ func (m *Manager) startSender() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) receiveACKs(client *client.GRPCClient) {
|
||||
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
|
||||
err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
id, err := uuid.FromBytes(ack.EventId)
|
||||
if err != nil {
|
||||
log.Warnf("failed to convert ack event id to uuid: %v", err)
|
||||
|
||||
200
client/internal/netflow/manager_test.go
Normal file
200
client/internal/netflow/manager_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package netflow
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type mockIFaceMapper struct {
|
||||
address wgaddr.Address
|
||||
isUserspaceBind bool
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Name() string {
|
||||
return "wt0"
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
||||
return m.address
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) IsUserspaceBind() bool {
|
||||
return m.isUserspaceBind
|
||||
}
|
||||
|
||||
func TestManager_Update(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
publicKey := []byte("test-public-key")
|
||||
statusRecorder := peer.NewRecorder("")
|
||||
|
||||
manager := NewManager(mockIFace, publicKey, statusRecorder)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config *types.FlowConfig
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "disabled config",
|
||||
config: &types.FlowConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled config with minimal valid settings",
|
||||
config: &types.FlowConfig{
|
||||
Enabled: true,
|
||||
URL: "https://example.com",
|
||||
TokenPayload: "test-payload",
|
||||
TokenSignature: "test-signature",
|
||||
Interval: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := manager.Update(tc.config)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
if tc.config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, manager.flowConfig)
|
||||
|
||||
if tc.config.Enabled {
|
||||
assert.Equal(t, tc.config.Enabled, manager.flowConfig.Enabled)
|
||||
}
|
||||
|
||||
if tc.config.URL != "" {
|
||||
assert.Equal(t, tc.config.URL, manager.flowConfig.URL)
|
||||
}
|
||||
|
||||
if tc.config.TokenPayload != "" {
|
||||
assert.Equal(t, tc.config.TokenPayload, manager.flowConfig.TokenPayload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Update_TokenPreservation(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
publicKey := []byte("test-public-key")
|
||||
manager := NewManager(mockIFace, publicKey, nil)
|
||||
|
||||
// First update with tokens
|
||||
initialConfig := &types.FlowConfig{
|
||||
Enabled: false,
|
||||
TokenPayload: "initial-payload",
|
||||
TokenSignature: "initial-signature",
|
||||
}
|
||||
|
||||
err := manager.Update(initialConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second update without tokens should preserve them
|
||||
updatedConfig := &types.FlowConfig{
|
||||
Enabled: false,
|
||||
URL: "https://example.com",
|
||||
}
|
||||
|
||||
err = manager.Update(updatedConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify tokens were preserved
|
||||
assert.Equal(t, "initial-payload", manager.flowConfig.TokenPayload)
|
||||
assert.Equal(t, "initial-signature", manager.flowConfig.TokenSignature)
|
||||
}
|
||||
|
||||
func TestManager_NeedsNewClient(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
previous *types.FlowConfig
|
||||
current *types.FlowConfig
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil previous config",
|
||||
previous: nil,
|
||||
current: &types.FlowConfig{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "previous disabled",
|
||||
previous: &types.FlowConfig{Enabled: false},
|
||||
current: &types.FlowConfig{Enabled: true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different URL",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "old-url"},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "new-url"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different TokenPayload",
|
||||
previous: &types.FlowConfig{Enabled: true, TokenPayload: "old-payload"},
|
||||
current: &types.FlowConfig{Enabled: true, TokenPayload: "new-payload"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different TokenSignature",
|
||||
previous: &types.FlowConfig{Enabled: true, TokenSignature: "old-signature"},
|
||||
current: &types.FlowConfig{Enabled: true, TokenSignature: "new-signature"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "same config",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "only interval changed",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 30 * time.Second},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 60 * time.Second},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
manager.flowConfig = tc.current
|
||||
result := manager.needsNewClient(tc.previous)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
const ZoneID = 0x1BD0
|
||||
|
||||
type Protocol uint8
|
||||
|
||||
const (
|
||||
@@ -120,9 +122,6 @@ type FlowLogger interface {
|
||||
Close()
|
||||
// Enable enables the flow logger receiver
|
||||
Enable()
|
||||
// Disable disables the flow logger receiver
|
||||
Disable()
|
||||
|
||||
// UpdateConfig updates the flow manager configuration
|
||||
UpdateConfig(dnsCollection, exitNodeCollection bool)
|
||||
}
|
||||
|
||||
@@ -60,6 +60,15 @@ type WgConfig struct {
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
|
||||
type RosenpassConfig struct {
|
||||
// RosenpassPubKey is this peer's Rosenpass public key
|
||||
PubKey []byte
|
||||
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
|
||||
Addr string
|
||||
|
||||
PermissiveMode bool
|
||||
}
|
||||
|
||||
// ConnConfig is a peer Connection configuration
|
||||
type ConnConfig struct {
|
||||
// Key is a public key of a remote peer
|
||||
@@ -73,10 +82,7 @@ type ConnConfig struct {
|
||||
|
||||
LocalWgPort int
|
||||
|
||||
// RosenpassPubKey is this peer's Rosenpass public key
|
||||
RosenpassPubKey []byte
|
||||
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
|
||||
RosenpassAddr string
|
||||
RosenpassConfig RosenpassConfig
|
||||
|
||||
// ICEConfig ICE protocol configuration
|
||||
ICEConfig icemaker.Config
|
||||
@@ -103,11 +109,14 @@ type Conn struct {
|
||||
|
||||
workerICE *WorkerICE
|
||||
workerRelay *WorkerRelay
|
||||
wgWatcherWg sync.WaitGroup
|
||||
|
||||
connIDRelay nbnet.ConnectionID
|
||||
connIDICE nbnet.ConnectionID
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
||||
rosenpassRemoteKey []byte
|
||||
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
@@ -211,6 +220,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||
func (conn *Conn) Close() {
|
||||
conn.mu.Lock()
|
||||
defer conn.wgWatcherWg.Wait()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
conn.log.Infof("close peer connection")
|
||||
@@ -252,6 +262,7 @@ func (conn *Conn) Close() {
|
||||
}
|
||||
|
||||
conn.setStatusToDisconnected()
|
||||
conn.log.Infof("peer connection has been closed")
|
||||
}
|
||||
|
||||
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
@@ -362,6 +373,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Pause()
|
||||
@@ -371,7 +383,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
if err = conn.configureWGEndpoint(ep); err != nil {
|
||||
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
return
|
||||
}
|
||||
@@ -404,10 +416,15 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
conn.dumpState.SwitchToRelay()
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
} else {
|
||||
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
|
||||
@@ -469,16 +486,22 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
@@ -502,6 +525,7 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
conn.currentConnPriority = connPriorityNone
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
@@ -541,13 +565,14 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
|
||||
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
|
||||
presharedKey := conn.presharedKey(remoteRPKey)
|
||||
return conn.config.WgConfig.WgInterface.UpdatePeer(
|
||||
conn.config.WgConfig.RemoteKey,
|
||||
conn.config.WgConfig.AllowedIps,
|
||||
defaultWgKeepAlive,
|
||||
addr,
|
||||
conn.config.WgConfig.PreSharedKey,
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -768,6 +793,44 @@ func (conn *Conn) AllowedIP() netip.Addr {
|
||||
return conn.config.WgConfig.AllowedIps[0].Addr()
|
||||
}
|
||||
|
||||
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
if conn.config.RosenpassConfig.PubKey == nil {
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
if remoteRosenpassKey == nil && conn.config.RosenpassConfig.PermissiveMode {
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
determKey, err := conn.rosenpassDetermKey()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
return determKey
|
||||
}
|
||||
|
||||
// todo: move this logic into Rosenpass package
|
||||
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
|
||||
lk := []byte(conn.config.LocalKey)
|
||||
rk := []byte(conn.config.Key) // remote key
|
||||
var keyInput []byte
|
||||
if string(lk) > string(rk) {
|
||||
//nolint:gocritic
|
||||
keyInput = append(lk[:16], rk[:16]...)
|
||||
} else {
|
||||
//nolint:gocritic
|
||||
keyInput = append(rk[:16], lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -161,3 +162,145 @@ func TestConn_Status(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_presharedKey(t *testing.T) {
|
||||
conn1 := Conn{
|
||||
config: ConnConfig{
|
||||
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
RosenpassConfig: RosenpassConfig{},
|
||||
},
|
||||
}
|
||||
conn2 := Conn{
|
||||
config: ConnConfig{
|
||||
Key: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
RosenpassConfig: RosenpassConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
conn1Permissive bool
|
||||
conn1RosenpassEnabled bool
|
||||
conn2Permissive bool
|
||||
conn2RosenpassEnabled bool
|
||||
conn1ExpectedInitialKey bool
|
||||
conn2ExpectedInitialKey bool
|
||||
}{
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: false,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: false,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: false,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: true,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: false,
|
||||
conn1ExpectedInitialKey: true,
|
||||
conn2ExpectedInitialKey: false,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: false,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
{
|
||||
conn1Permissive: true,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: false,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: false,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: false,
|
||||
conn2Permissive: true,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: false,
|
||||
},
|
||||
{
|
||||
conn1Permissive: true,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: true,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: true,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: false,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: true,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: true,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
}
|
||||
|
||||
conn1.config.RosenpassConfig.PermissiveMode = true
|
||||
for i, test := range tests {
|
||||
tcase := i + 1
|
||||
t.Run(fmt.Sprintf("Rosenpass test case %d", tcase), func(t *testing.T) {
|
||||
conn1.config.RosenpassConfig = RosenpassConfig{}
|
||||
conn2.config.RosenpassConfig = RosenpassConfig{}
|
||||
|
||||
if test.conn1RosenpassEnabled {
|
||||
conn1.config.RosenpassConfig.PubKey = []byte("dummykey")
|
||||
}
|
||||
conn1.config.RosenpassConfig.PermissiveMode = test.conn1Permissive
|
||||
|
||||
if test.conn2RosenpassEnabled {
|
||||
conn2.config.RosenpassConfig.PubKey = []byte("dummykey")
|
||||
}
|
||||
conn2.config.RosenpassConfig.PermissiveMode = test.conn2Permissive
|
||||
|
||||
conn1PresharedKey := conn1.presharedKey(conn2.config.RosenpassConfig.PubKey)
|
||||
conn2PresharedKey := conn2.presharedKey(conn1.config.RosenpassConfig.PubKey)
|
||||
|
||||
if test.conn1ExpectedInitialKey {
|
||||
if conn1PresharedKey == nil {
|
||||
t.Errorf("Case %d: Expected conn1 to have a non-nil key, but got nil", tcase)
|
||||
}
|
||||
} else {
|
||||
if conn1PresharedKey != nil {
|
||||
t.Errorf("Case %d: Expected conn1 to have a nil key, but got %v", tcase, conn1PresharedKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Assert conn2's key expectation
|
||||
if test.conn2ExpectedInitialKey {
|
||||
if conn2PresharedKey == nil {
|
||||
t.Errorf("Case %d: Expected conn2 to have a non-nil key, but got nil", tcase)
|
||||
}
|
||||
} else {
|
||||
if conn2PresharedKey != nil {
|
||||
t.Errorf("Case %d: Expected conn2 to have a nil key, but got %v", tcase, conn2PresharedKey)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,8 +154,8 @@ func (h *Handshaker) sendOffer() error {
|
||||
IceCredentials: IceCredentials{iceUFrag, icePwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassPubKey,
|
||||
RosenpassAddr: h.config.RosenpassAddr,
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
}
|
||||
|
||||
addr, err := h.relay.RelayInstanceAddress()
|
||||
@@ -174,8 +174,8 @@ func (h *Handshaker) sendAnswer() error {
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassPubKey,
|
||||
RosenpassAddr: h.config.RosenpassAddr,
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
}
|
||||
addr, err := h.relay.RelayInstanceAddress()
|
||||
if err == nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/randutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -35,6 +36,10 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
fac := logging.NewDefaultLoggerFactory()
|
||||
|
||||
//fac.Writer = log.StandardLogger().Writer()
|
||||
|
||||
agentConfig := &ice.AgentConfig{
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||
@@ -51,6 +56,7 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida
|
||||
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
|
||||
LocalUfrag: ufrag,
|
||||
LocalPwd: pwd,
|
||||
LoggerFactory: fac,
|
||||
}
|
||||
|
||||
if config.DisableIPv6Discovery {
|
||||
|
||||
@@ -2,37 +2,89 @@ package peer
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// routeEntry holds the route prefix and the corresponding resource ID.
|
||||
type routeEntry struct {
|
||||
prefix netip.Prefix
|
||||
resourceID string
|
||||
}
|
||||
|
||||
type routeIDLookup struct {
|
||||
localMap sync.Map
|
||||
remoteMap sync.Map
|
||||
localRoutes []routeEntry
|
||||
localLock sync.RWMutex
|
||||
|
||||
remoteRoutes []routeEntry
|
||||
remoteLock sync.RWMutex
|
||||
|
||||
resolvedIPs sync.Map
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) {
|
||||
_, exists := r.localMap.LoadOrStore(route, resourceID)
|
||||
if exists {
|
||||
log.Tracef("resourceID %s already exists in local map", resourceID)
|
||||
r.localLock.Lock()
|
||||
defer r.localLock.Unlock()
|
||||
|
||||
// update the resource id if the route already exists.
|
||||
for i, entry := range r.localRoutes {
|
||||
if entry.prefix == route {
|
||||
r.localRoutes[i].resourceID = resourceID
|
||||
log.Tracef("resourceID for route %v updated to %s in local routes", route, resourceID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// append and sort descending by prefix bits (more specific first)
|
||||
r.localRoutes = append(r.localRoutes, routeEntry{prefix: route, resourceID: resourceID})
|
||||
sort.Slice(r.localRoutes, func(i, j int) bool {
|
||||
return r.localRoutes[i].prefix.Bits() > r.localRoutes[j].prefix.Bits()
|
||||
})
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
|
||||
r.localMap.Delete(route)
|
||||
}
|
||||
r.localLock.Lock()
|
||||
defer r.localLock.Unlock()
|
||||
|
||||
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
|
||||
_, exists := r.remoteMap.LoadOrStore(route, resourceID)
|
||||
if exists {
|
||||
log.Tracef("resourceID %s already exists in remote map", resourceID)
|
||||
for i, entry := range r.localRoutes {
|
||||
if entry.prefix == route {
|
||||
r.localRoutes = append(r.localRoutes[:i], r.localRoutes[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
|
||||
r.remoteLock.Lock()
|
||||
defer r.remoteLock.Unlock()
|
||||
|
||||
for i, entry := range r.remoteRoutes {
|
||||
if entry.prefix == route {
|
||||
r.remoteRoutes[i].resourceID = resourceID
|
||||
log.Tracef("resourceID for route %v updated to %s in remote routes", route, resourceID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// append and sort descending by prefix bits.
|
||||
r.remoteRoutes = append(r.remoteRoutes, routeEntry{prefix: route, resourceID: resourceID})
|
||||
sort.Slice(r.remoteRoutes, func(i, j int) bool {
|
||||
return r.remoteRoutes[i].prefix.Bits() > r.remoteRoutes[j].prefix.Bits()
|
||||
})
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
|
||||
r.remoteMap.Delete(route)
|
||||
r.remoteLock.Lock()
|
||||
defer r.remoteLock.Unlock()
|
||||
|
||||
for i, entry := range r.remoteRoutes {
|
||||
if entry.prefix == route {
|
||||
r.remoteRoutes = append(r.remoteRoutes[:i], r.remoteRoutes[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) {
|
||||
@@ -44,37 +96,35 @@ func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
|
||||
}
|
||||
|
||||
// Lookup returns the resource ID for the given IP address
|
||||
// and a bool indicating if the IP is an exit node
|
||||
// and a bool indicating if the IP is an exit node.
|
||||
func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) {
|
||||
var isExitNode bool
|
||||
|
||||
resId, ok := r.resolvedIPs.Load(ip)
|
||||
if ok {
|
||||
return resId.(string), false
|
||||
if res, ok := r.resolvedIPs.Load(ip); ok {
|
||||
return res.(string), false
|
||||
}
|
||||
|
||||
var resourceID string
|
||||
r.localMap.Range(func(key, value interface{}) bool {
|
||||
pref := key.(netip.Prefix)
|
||||
if pref.Contains(ip) {
|
||||
resourceID = value.(string)
|
||||
isExitNode = pref.Bits() == 0
|
||||
return false
|
||||
var isExitNode bool
|
||||
|
||||
r.localLock.RLock()
|
||||
for _, entry := range r.localRoutes {
|
||||
if entry.prefix.Contains(ip) {
|
||||
resourceID = entry.resourceID
|
||||
isExitNode = (entry.prefix.Bits() == 0)
|
||||
break
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
r.localLock.RUnlock()
|
||||
|
||||
if resourceID == "" {
|
||||
r.remoteMap.Range(func(key, value interface{}) bool {
|
||||
pref := key.(netip.Prefix)
|
||||
if pref.Contains(ip) {
|
||||
resourceID = value.(string)
|
||||
isExitNode = pref.Bits() == 0
|
||||
return false
|
||||
r.remoteLock.RLock()
|
||||
for _, entry := range r.remoteRoutes {
|
||||
if entry.prefix.Contains(ip) {
|
||||
resourceID = entry.resourceID
|
||||
isExitNode = (entry.prefix.Bits() == 0)
|
||||
break
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
r.remoteLock.RUnlock()
|
||||
}
|
||||
|
||||
return resourceID, isExitNode
|
||||
|
||||
@@ -586,9 +586,8 @@ func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
return
|
||||
if err == nil {
|
||||
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
if d.localPeer.Routes == nil {
|
||||
@@ -596,8 +595,6 @@ func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
|
||||
}
|
||||
|
||||
d.localPeer.Routes[route] = struct{}{}
|
||||
|
||||
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
// RemoveLocalPeerStateRoute removes a route from the local peer state
|
||||
@@ -606,14 +603,33 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) {
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
return
|
||||
if err == nil {
|
||||
d.routeIDLookup.RemoveLocalRouteID(pref)
|
||||
}
|
||||
|
||||
delete(d.localPeer.Routes, route)
|
||||
}
|
||||
|
||||
d.routeIDLookup.RemoveLocalRouteID(pref)
|
||||
// AddResolvedIPLookupEntry adds a resolved IP lookup entry
|
||||
func (d *Status) AddResolvedIPLookupEntry(route, resourceId string) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
d.routeIDLookup.AddResolvedIP(resourceId, pref)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveResolvedIPLookupEntry removes a resolved IP lookup entry
|
||||
func (d *Status) RemoveResolvedIPLookupEntry(route string) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
d.routeIDLookup.RemoveResolvedIP(pref)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
|
||||
|
||||
@@ -32,7 +32,6 @@ type WGWatcher struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
waitGroup sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -48,24 +47,24 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctx != nil && w.ctx.Err() == nil {
|
||||
w.log.Errorf("WireGuard watcher already enabled")
|
||||
w.ctxLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(parentCtx)
|
||||
w.ctx = ctx
|
||||
w.ctxCancel = ctxCancel
|
||||
w.ctxLock.Unlock()
|
||||
|
||||
initialHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||
}
|
||||
|
||||
w.waitGroup.Add(1)
|
||||
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||
}
|
||||
|
||||
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
|
||||
@@ -81,13 +80,11 @@ func (w *WGWatcher) DisableWgWatcher() {
|
||||
|
||||
w.ctxCancel()
|
||||
w.ctxCancel = nil
|
||||
w.waitGroup.Wait()
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
defer w.waitGroup.Done()
|
||||
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
|
||||
@@ -49,7 +49,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
watcher.EnableWgWatcher(ctx, func() {
|
||||
go watcher.EnableWgWatcher(ctx, func() {
|
||||
mlog.Infof("onDisconnectedFn")
|
||||
onDisconnected <- struct{}{}
|
||||
})
|
||||
@@ -79,10 +79,11 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
|
||||
watcher.EnableWgWatcher(ctx, func() {})
|
||||
go watcher.EnableWgWatcher(ctx, func() {})
|
||||
time.Sleep(1 * time.Second)
|
||||
watcher.DisableWgWatcher()
|
||||
|
||||
watcher.EnableWgWatcher(ctx, func() {
|
||||
go watcher.EnableWgWatcher(ctx, func() {
|
||||
onDisconnected <- struct{}{}
|
||||
})
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
|
||||
iFaceDiscover: ifaceDiscover,
|
||||
statusRecorder: statusRecorder,
|
||||
hasRelayOnLocally: hasRelayOnLocally,
|
||||
lastKnownState: ice.ConnectionStateDisconnected,
|
||||
}
|
||||
|
||||
localUfrag, localPwd, err := icemaker.GenerateICECredentials()
|
||||
@@ -213,7 +214,7 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
w.lastKnownState = ice.ConnectionStateConnected
|
||||
return
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected()
|
||||
}
|
||||
|
||||
@@ -39,6 +39,8 @@ type PKCEAuthProviderConfig struct {
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
@@ -97,6 +99,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
|
||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
ClientCertPair: clientCert,
|
||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
@@ -60,7 +59,7 @@ func (d *DnsInterceptor) String() string {
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
|
||||
d.dnsServer.RegisterHandler(d.route.Domains, d, nbdns.PriorityDNSRoute)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -89,7 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
clear(d.interceptedDomains)
|
||||
d.mu.Unlock()
|
||||
|
||||
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
|
||||
d.dnsServer.DeregisterHandler(d.route.Domains, nbdns.PriorityDNSRoute)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
@@ -142,57 +141,67 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
// pass if non A/AAAA query
|
||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||
d.continueToNextHandler(w, r, "non A/AAAA query")
|
||||
return
|
||||
}
|
||||
|
||||
d.mu.RLock()
|
||||
peerKey := d.currentPeerKey
|
||||
d.mu.RUnlock()
|
||||
|
||||
if peerKey == "" {
|
||||
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
|
||||
|
||||
d.continueToNextHandler(w, r, "no current peer key")
|
||||
d.writeDNSError(w, r, "no current peer key")
|
||||
return
|
||||
}
|
||||
|
||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get upstream IP: %v", err)
|
||||
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
|
||||
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||
if r.Extra == nil {
|
||||
r.SetEdns0(4096, false)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: nbdns.UpstreamTimeout,
|
||||
Net: "udp",
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||
|
||||
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
||||
if err != nil {
|
||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
|
||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||
|
||||
reply.Id = r.Id
|
||||
if err := d.writeMsg(w, reply); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
|
||||
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeServerFailure)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS error response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// continueToNextHandler signals the handler chain to try the next handler
|
||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||
|
||||
@@ -235,7 +235,7 @@ func (r *Route) resolve(results chan resolveResult) {
|
||||
ips, err := r.getIPsFromResolver(domain)
|
||||
if err != nil {
|
||||
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
|
||||
ips, err = net.LookupIP(string(domain))
|
||||
ips, err = net.LookupIP(domain.PunycodeString())
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||
return
|
||||
|
||||
@@ -9,5 +9,5 @@ import (
|
||||
)
|
||||
|
||||
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||
return net.LookupIP(string(domain))
|
||||
return net.LookupIP(domain.PunycodeString())
|
||||
}
|
||||
|
||||
@@ -23,11 +23,11 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
|
||||
msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
response, _, err := privateClient.Exchange(msg, r.resolverAddr)
|
||||
response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
|
||||
}
|
||||
|
||||
@@ -55,6 +55,18 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
||||
delete(m.routes, routeID)
|
||||
}
|
||||
|
||||
// If routing is to be disabled, do it after routes have been removed
|
||||
// If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled
|
||||
if len(routesMap) > 0 {
|
||||
if err := m.firewall.EnableRouting(); err != nil {
|
||||
return fmt.Errorf("enable routing: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.firewall.DisableRouting(); err != nil {
|
||||
return fmt.Errorf("disable routing: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for id, newRoute := range routesMap {
|
||||
_, found := m.routes[id]
|
||||
if found {
|
||||
@@ -69,16 +81,6 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
||||
m.routes[id] = newRoute
|
||||
}
|
||||
|
||||
if len(m.routes) > 0 {
|
||||
if err := m.firewall.EnableRouting(); err != nil {
|
||||
return fmt.Errorf("enable routing: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.firewall.DisableRouting(); err != nil {
|
||||
return fmt.Errorf("disable routing: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -103,7 +105,11 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
|
||||
delete(m.routes, route.ID)
|
||||
|
||||
m.statusRecorder.RemoveLocalPeerStateRoute(route.Network.String())
|
||||
routeStr := route.Network.String()
|
||||
if route.IsDynamic() {
|
||||
routeStr = route.Domains.SafeString()
|
||||
}
|
||||
m.statusRecorder.RemoveLocalPeerStateRoute(routeStr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -57,8 +57,8 @@ func getSetupRules() []ruleParams {
|
||||
return []ruleParams{
|
||||
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
||||
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
|
||||
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
|
||||
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
|
||||
{110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
|
||||
{110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,20 +10,27 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
route "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouteSelector struct {
|
||||
mu sync.RWMutex
|
||||
selectedRoutes map[route.NetID]struct{}
|
||||
selectAll bool
|
||||
|
||||
// Indicates if new routes should be automatically selected
|
||||
includeNewRoutes bool
|
||||
|
||||
// All known routes at the time of deselection
|
||||
knownRoutes []route.NetID
|
||||
}
|
||||
|
||||
func NewRouteSelector() *RouteSelector {
|
||||
return &RouteSelector{
|
||||
selectedRoutes: map[route.NetID]struct{}{},
|
||||
// default selects all routes
|
||||
selectAll: true,
|
||||
selectedRoutes: map[route.NetID]struct{}{},
|
||||
selectAll: true,
|
||||
includeNewRoutes: false,
|
||||
knownRoutes: []route.NetID{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +53,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
rs.selectedRoutes[route] = struct{}{}
|
||||
}
|
||||
rs.selectAll = false
|
||||
rs.includeNewRoutes = false
|
||||
|
||||
return errors.FormatErrorOrNil(err)
|
||||
}
|
||||
@@ -57,16 +65,22 @@ func (rs *RouteSelector) SelectAllRoutes() {
|
||||
|
||||
rs.selectAll = true
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
rs.includeNewRoutes = false
|
||||
}
|
||||
|
||||
// DeselectRoutes removes specific routes from the selection.
|
||||
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
||||
// If the selector is in "select all" mode, it will transition to "select specific" mode
|
||||
// but will keep new routes selected.
|
||||
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
if rs.selectAll {
|
||||
rs.selectAll = false
|
||||
rs.includeNewRoutes = true
|
||||
rs.knownRoutes = make([]route.NetID, len(allRoutes))
|
||||
copy(rs.knownRoutes, allRoutes)
|
||||
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
for _, route := range allRoutes {
|
||||
rs.selectedRoutes[route] = struct{}{}
|
||||
@@ -92,6 +106,7 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
rs.selectAll = false
|
||||
rs.includeNewRoutes = false
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
|
||||
@@ -103,8 +118,20 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
if rs.selectAll {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if the route exists in selectedRoutes
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
return selected
|
||||
if selected {
|
||||
return true
|
||||
}
|
||||
|
||||
// If includeNewRoutes is true and this is a new route (not in knownRoutes),
|
||||
// then it should be selected
|
||||
if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
@@ -118,7 +145,11 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
|
||||
filtered := route.HAMap{}
|
||||
for id, rt := range routes {
|
||||
if rs.IsSelected(id.NetID()) {
|
||||
netID := id.NetID()
|
||||
_, selected := rs.selectedRoutes[netID]
|
||||
|
||||
// Include if directly selected or if it's a new route and includeNewRoutes is true
|
||||
if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) {
|
||||
filtered[id] = rt
|
||||
}
|
||||
}
|
||||
@@ -131,11 +162,15 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||
SelectAll bool `json:"select_all"`
|
||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||
SelectAll bool `json:"select_all"`
|
||||
IncludeNewRoutes bool `json:"include_new_routes"`
|
||||
KnownRoutes []route.NetID `json:"known_routes"`
|
||||
}{
|
||||
SelectAll: rs.selectAll,
|
||||
SelectedRoutes: rs.selectedRoutes,
|
||||
SelectAll: rs.selectAll,
|
||||
SelectedRoutes: rs.selectedRoutes,
|
||||
IncludeNewRoutes: rs.includeNewRoutes,
|
||||
KnownRoutes: rs.knownRoutes,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -149,12 +184,16 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
rs.selectAll = true
|
||||
rs.includeNewRoutes = false
|
||||
rs.knownRoutes = []route.NetID{}
|
||||
return nil
|
||||
}
|
||||
|
||||
var temp struct {
|
||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||
SelectAll bool `json:"select_all"`
|
||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||
SelectAll bool `json:"select_all"`
|
||||
IncludeNewRoutes bool `json:"include_new_routes"`
|
||||
KnownRoutes []route.NetID `json:"known_routes"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
@@ -163,10 +202,15 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
||||
|
||||
rs.selectedRoutes = temp.SelectedRoutes
|
||||
rs.selectAll = temp.SelectAll
|
||||
rs.includeNewRoutes = temp.IncludeNewRoutes
|
||||
rs.knownRoutes = temp.KnownRoutes
|
||||
|
||||
if rs.selectedRoutes == nil {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
if rs.knownRoutes == nil {
|
||||
rs.knownRoutes = []route.NetID{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -316,7 +316,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
||||
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
|
||||
},
|
||||
// After deselecting specific routes, new routes should remain unselected
|
||||
wantNewSelected: []route.NetID{"route2", "route3"},
|
||||
wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"},
|
||||
},
|
||||
{
|
||||
name: "New routes after selecting with append",
|
||||
@@ -358,3 +358,73 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteSelector_MixedSelectionDeselection(t *testing.T) {
|
||||
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
routesToSelect []route.NetID
|
||||
selectAppend bool
|
||||
routesToDeselect []route.NetID
|
||||
selectFirst bool
|
||||
wantSelectedFinal []route.NetID
|
||||
}{
|
||||
{
|
||||
name: "1. Select A, then Deselect B",
|
||||
routesToSelect: []route.NetID{"route1"},
|
||||
selectAppend: false,
|
||||
routesToDeselect: []route.NetID{"route2"},
|
||||
selectFirst: true,
|
||||
wantSelectedFinal: []route.NetID{"route1"},
|
||||
},
|
||||
{
|
||||
name: "2. Select A, then Deselect A",
|
||||
routesToSelect: []route.NetID{"route1"},
|
||||
selectAppend: false,
|
||||
routesToDeselect: []route.NetID{"route1"},
|
||||
selectFirst: true,
|
||||
wantSelectedFinal: []route.NetID{},
|
||||
},
|
||||
{
|
||||
name: "3. Deselect A (from all), then Select B",
|
||||
routesToSelect: []route.NetID{"route2"},
|
||||
selectAppend: false,
|
||||
routesToDeselect: []route.NetID{"route1"},
|
||||
selectFirst: false,
|
||||
wantSelectedFinal: []route.NetID{"route2"},
|
||||
},
|
||||
{
|
||||
name: "4. Deselect A (from all), then Select A",
|
||||
routesToSelect: []route.NetID{"route1"},
|
||||
selectAppend: false,
|
||||
routesToDeselect: []route.NetID{"route1"},
|
||||
selectFirst: false,
|
||||
wantSelectedFinal: []route.NetID{"route1"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
var err1, err2 error
|
||||
|
||||
if tt.selectFirst {
|
||||
err1 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes)
|
||||
require.NoError(t, err1)
|
||||
err2 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes)
|
||||
require.NoError(t, err2)
|
||||
} else {
|
||||
err1 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes)
|
||||
require.NoError(t, err1)
|
||||
err2 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes)
|
||||
require.NoError(t, err2)
|
||||
}
|
||||
|
||||
for _, r := range allRoutes {
|
||||
assert.Equal(t, slices.Contains(tt.wantSelectedFinal, r), rs.IsSelected(r), "Route %s final state mismatch", r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +0,0 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// collectFirewallRules returns nothing on non-linux systems
|
||||
func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
return nil
|
||||
}
|
||||
@@ -100,7 +100,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
|
||||
// Convert to proto format
|
||||
for domain, ips := range domainMap {
|
||||
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
|
||||
pbRoute.ResolvedIPs[domain.PunycodeString()] = &proto.IPList{
|
||||
Ips: ips,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,6 +84,7 @@ func New(ctx context.Context, configPath, logFile string) *Server {
|
||||
},
|
||||
logFile: logFile,
|
||||
persistNetworkMap: true,
|
||||
statusRecorder: peer.NewRecorder(""),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,9 +137,6 @@ func (s *Server) Start() error {
|
||||
|
||||
s.config = config
|
||||
|
||||
if s.statusRecorder == nil {
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
}
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
|
||||
@@ -622,9 +620,6 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
||||
return nil, fmt.Errorf("config is not defined, please call login command first")
|
||||
}
|
||||
|
||||
if s.statusRecorder == nil {
|
||||
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
|
||||
}
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
@@ -692,9 +687,6 @@ func (s *Server) Status(
|
||||
|
||||
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
|
||||
|
||||
if s.statusRecorder == nil {
|
||||
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
|
||||
}
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -11,19 +12,24 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
@@ -83,6 +89,72 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Up(t *testing.T) {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
s := New(ctx, t.TempDir()+"/config.json", "console")
|
||||
|
||||
err := s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
|
||||
require.NoError(t, err)
|
||||
s.config = &internal.Config{
|
||||
ManagementURL: u,
|
||||
}
|
||||
|
||||
upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
upReq := &daemonProto.UpRequest{}
|
||||
_, err = s.Up(upCtx, upReq)
|
||||
|
||||
assert.Contains(t, err.Error(), "NeedsLogin")
|
||||
}
|
||||
|
||||
type mockSubscribeEventsServer struct {
|
||||
ctx context.Context
|
||||
sentEvents []*daemonProto.SystemEvent
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (m *mockSubscribeEventsServer) Send(event *daemonProto.SystemEvent) error {
|
||||
m.sentEvents = append(m.sentEvents, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSubscribeEventsServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func TestServer_SubcribeEvents(t *testing.T) {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
s := New(ctx, t.TempDir()+"/config.json", "console")
|
||||
|
||||
err := s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
|
||||
require.NoError(t, err)
|
||||
s.config = &internal.Config{
|
||||
ManagementURL: u,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
upReq := &daemonProto.SubscribeRequest{}
|
||||
mockServer := &mockSubscribeEventsServer{
|
||||
ctx: ctx,
|
||||
sentEvents: make([]*daemonProto.SystemEvent, 0),
|
||||
ServerStream: nil,
|
||||
}
|
||||
err = s.SubscribeEvents(upReq, mockServer)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mgmtProto.ManagementServiceServer
|
||||
counter *int
|
||||
@@ -97,10 +169,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
Signal: &server.Host{
|
||||
config := &types.Config{
|
||||
Stuns: []*types.Host{},
|
||||
TURNConfig: &types.TURNConfig{},
|
||||
Signal: &types.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
@@ -132,8 +204,9 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -13,10 +13,12 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
@@ -77,17 +79,24 @@ func (c *GRPCClient) Close() error {
|
||||
defer c.streamMu.Unlock()
|
||||
|
||||
c.stream = nil
|
||||
return c.clientConn.Close()
|
||||
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
return fmt.Errorf("close client connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
backOff := defaultBackoff(ctx, interval)
|
||||
operation := func() error {
|
||||
err := c.establishStreamAndReceive(ctx, msgHandler)
|
||||
if err != nil {
|
||||
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
|
||||
}
|
||||
log.Errorf("receive failed: %v", err)
|
||||
return fmt.Errorf("receive: %w", err)
|
||||
}
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := backoff.Retry(operation, backOff); err != nil {
|
||||
|
||||
256
flow/client/client_test.go
Normal file
256
flow/client/client_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
flow "github.com/netbirdio/netbird/flow/client"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
)
|
||||
|
||||
type testServer struct {
|
||||
proto.UnimplementedFlowServiceServer
|
||||
events chan *proto.FlowEvent
|
||||
acks chan *proto.FlowEventAck
|
||||
grpcSrv *grpc.Server
|
||||
addr string
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T) *testServer {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &testServer{
|
||||
events: make(chan *proto.FlowEvent, 100),
|
||||
acks: make(chan *proto.FlowEventAck, 100),
|
||||
grpcSrv: grpc.NewServer(),
|
||||
addr: listener.Addr().String(),
|
||||
}
|
||||
|
||||
proto.RegisterFlowServiceServer(s.grpcSrv, s)
|
||||
|
||||
go func() {
|
||||
if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
|
||||
t.Logf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.grpcSrv.Stop()
|
||||
})
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
||||
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(stream.Context())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
for {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !event.IsInitiator {
|
||||
select {
|
||||
case s.events <- event:
|
||||
ack := &proto.FlowEventAck{
|
||||
EventId: event.EventId,
|
||||
}
|
||||
select {
|
||||
case s.acks <- ack:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ack := <-s.acks:
|
||||
if err := stream.Send(ack); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceive(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
receivedAcks := make(map[string]bool)
|
||||
receiveDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if !msg.IsInitiator && len(msg.EventId) > 0 {
|
||||
id := string(msg.EventId)
|
||||
receivedAcks[id] = true
|
||||
|
||||
if len(receivedAcks) >= 3 {
|
||||
close(receiveDone)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Logf("receive error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
eventID := uuid.New().String()
|
||||
|
||||
// Create acknowledgment and send it to the flow through our test server
|
||||
ack := &proto.FlowEventAck{
|
||||
EventId: []byte(eventID),
|
||||
}
|
||||
|
||||
select {
|
||||
case server.acks <- ack:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout sending ack")
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for acks to be processed")
|
||||
}
|
||||
|
||||
assert.Equal(t, 3, len(receivedAcks))
|
||||
}
|
||||
|
||||
func TestReceive_ContextCancellation(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
handlerCalled := false
|
||||
msgHandler := func(msg *proto.FlowEventAck) error {
|
||||
if !msg.IsInitiator {
|
||||
handlerCalled = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Receive(ctx, 1*time.Second, msgHandler)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
assert.False(t, handlerCalled)
|
||||
}
|
||||
|
||||
func TestSend(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
ackReceived := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(ack *proto.FlowEventAck) error {
|
||||
if len(ack.EventId) > 0 && !ack.IsInitiator {
|
||||
close(ackReceived)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Logf("receive error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
testEvent := &proto.FlowEvent{
|
||||
EventId: []byte("test-event-id"),
|
||||
PublicKey: []byte("test-public-key"),
|
||||
FlowFields: &proto.FlowFields{
|
||||
FlowId: []byte("test-flow-id"),
|
||||
Protocol: 6,
|
||||
SourceIp: []byte{192, 168, 1, 1},
|
||||
DestIp: []byte{192, 168, 1, 2},
|
||||
ConnectionInfo: &proto.FlowFields_PortInfo{
|
||||
PortInfo: &proto.PortInfo{
|
||||
SourcePort: 12345,
|
||||
DestPort: 443,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = client.Send(testEvent)
|
||||
require.NoError(t, err)
|
||||
|
||||
var receivedEvent *proto.FlowEvent
|
||||
select {
|
||||
case receivedEvent = <-server.events:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for event to be received by server")
|
||||
}
|
||||
|
||||
assert.Equal(t, testEvent.EventId, receivedEvent.EventId)
|
||||
assert.Equal(t, testEvent.PublicKey, receivedEvent.PublicKey)
|
||||
|
||||
select {
|
||||
case <-ackReceived:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ack to be received by flow")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package hook
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
@@ -40,8 +41,12 @@ func (hook ContextHook) Levels() []logrus.Level {
|
||||
|
||||
// Fire extend with the source information the entry.Data
|
||||
func (hook ContextHook) Fire(entry *logrus.Entry) error {
|
||||
src := hook.parseSrc(entry.Caller.File)
|
||||
entry.Data[EntryKeySource] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
|
||||
caller := &runtime.Frame{Line: 0, File: "caller_not_available"}
|
||||
if entry.Caller != nil {
|
||||
caller = entry.Caller
|
||||
}
|
||||
src := hook.parseSrc(caller.File)
|
||||
entry.Data[EntryKeySource] = fmt.Sprintf("%s:%v", src, caller.Line)
|
||||
additionalEntries(entry)
|
||||
|
||||
if entry.Context == nil {
|
||||
|
||||
17
go.mod
17
go.mod
@@ -25,7 +25,7 @@ require (
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.64.1
|
||||
google.golang.org/protobuf v1.35.2
|
||||
google.golang.org/protobuf v1.36.5
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
)
|
||||
|
||||
@@ -47,7 +47,7 @@ require (
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/nftables v0.2.0
|
||||
github.com/gopacket/gopacket v1.1.1
|
||||
@@ -62,7 +62,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
@@ -73,7 +73,7 @@ require (
|
||||
github.com/pion/stun/v2 v2.0.0
|
||||
github.com/pion/transport/v3 v3.0.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/prometheus/client_golang v1.22.0
|
||||
github.com/quic-go/quic-go v0.48.2
|
||||
github.com/redis/go-redis/v9 v9.7.1
|
||||
github.com/rs/xid v1.3.0
|
||||
@@ -101,7 +101,7 @@ require (
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/net v0.36.0
|
||||
golang.org/x/oauth2 v0.19.0
|
||||
golang.org/x/oauth2 v0.24.0
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/term v0.30.0
|
||||
google.golang.org/api v0.177.0
|
||||
@@ -186,7 +186,7 @@ require (
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
|
||||
github.com/kelseyhightower/envconfig v1.4.0 // indirect
|
||||
github.com/klauspost/compress v1.17.8 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/libdns/libdns v0.2.2 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
||||
@@ -201,6 +201,7 @@ require (
|
||||
github.com/moby/sys/userns v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
|
||||
github.com/nxadm/tail v1.4.8 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
@@ -213,8 +214,8 @@ require (
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.53.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.0 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/rymdport/portal v0.3.0 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||
|
||||
35
go.sum
35
go.sum
@@ -292,8 +292,9 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
@@ -411,8 +412,8 @@ github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dv
|
||||
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
|
||||
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
@@ -424,6 +425,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
@@ -482,6 +485,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN
|
||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
|
||||
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
||||
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
|
||||
@@ -490,8 +495,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939 h1:OsLDdb6ekNaCVSyD+omhio2DECfEqLjCA1zo4HrgGqU=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203 h1:uxxbLPXQgC9VO15epNPtrD6zazyd5rZeqC5hQSmCdZU=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
@@ -560,15 +565,15 @@ github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndr
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
|
||||
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
|
||||
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
|
||||
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+aLCE=
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
|
||||
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/redis/go-redis/v9 v9.7.1 h1:4LhKRCIduqXqtvCUlaq9c8bdHOkICjDMrr1+Zb3osAc=
|
||||
@@ -856,8 +861,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
|
||||
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
||||
golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
|
||||
golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8=
|
||||
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
|
||||
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -1152,8 +1157,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io=
|
||||
google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -58,6 +58,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
|
||||
# PKCE authorization flow
|
||||
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
|
||||
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
|
||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
|
||||
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||
|
||||
# Dashboard
|
||||
@@ -120,6 +121,7 @@ export NETBIRD_AUTH_DEVICE_AUTH_SCOPE
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
|
||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
|
||||
export NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
|
||||
export NETBIRD_AUTH_PKCE_AUDIENCE
|
||||
export NETBIRD_DASH_AUTH_USE_AUDIENCE
|
||||
export NETBIRD_DASH_AUTH_AUDIENCE
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user