mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-29 12:09:59 +00:00
Compare commits
3 Commits
feature/an
...
feature/us
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d3e5f508c | ||
|
|
8d09ded1db | ||
|
|
a49a052f05 |
@@ -31,7 +31,7 @@ jobs:
|
|||||||
while IFS= read -r dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
if [ -n "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
# Check if any importer is NOT in management/signal/relay
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
|
||||||
skip: go.mod,go.sum,**/proxy/web/**
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
62
.github/workflows/proto-version-check.yml
vendored
62
.github/workflows/proto-version-check.yml
vendored
@@ -1,62 +0,0 @@
|
|||||||
name: Proto Version Check
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- "**/*.pb.go"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
check-proto-versions:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Check for proto tool version changes
|
|
||||||
uses: actions/github-script@v7
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
pull_number: context.issue.number,
|
|
||||||
per_page: 100,
|
|
||||||
});
|
|
||||||
|
|
||||||
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
|
||||||
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
|
||||||
if (missingPatch.length > 0) {
|
|
||||||
core.setFailed(
|
|
||||||
`Cannot inspect patch data for:\n` +
|
|
||||||
missingPatch.map(f => `- ${f}`).join('\n') +
|
|
||||||
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
|
||||||
const violations = [];
|
|
||||||
|
|
||||||
for (const file of pbFiles) {
|
|
||||||
const changed = file.patch
|
|
||||||
.split('\n')
|
|
||||||
.filter(line => versionPattern.test(line));
|
|
||||||
if (changed.length > 0) {
|
|
||||||
violations.push({
|
|
||||||
file: file.filename,
|
|
||||||
lines: changed,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (violations.length > 0) {
|
|
||||||
const details = violations.map(v =>
|
|
||||||
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
|
||||||
).join('\n\n');
|
|
||||||
|
|
||||||
core.setFailed(
|
|
||||||
`Proto version strings changed in generated files.\n` +
|
|
||||||
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
|
|
||||||
`Regenerate with the matching tool versions.\n\n` +
|
|
||||||
details
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
console.log('No proto version string changes detected');
|
|
||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.2"
|
SIGN_PIPE_VER: "v0.1.1"
|
||||||
GORELEASER_VER: "v2.14.3"
|
GORELEASER_VER: "v2.14.3"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
@@ -154,26 +154,6 @@ builds:
|
|||||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird-idp-migrate
|
|
||||||
dir: tools/idp-migrate
|
|
||||||
env:
|
|
||||||
- CGO_ENABLED=1
|
|
||||||
- >-
|
|
||||||
{{- if eq .Runtime.Goos "linux" }}
|
|
||||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
|
||||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
|
||||||
{{- end }}
|
|
||||||
binary: netbird-idp-migrate
|
|
||||||
goos:
|
|
||||||
- linux
|
|
||||||
goarch:
|
|
||||||
- amd64
|
|
||||||
- arm64
|
|
||||||
- arm
|
|
||||||
ldflags:
|
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
|
||||||
|
|
||||||
universal_binaries:
|
universal_binaries:
|
||||||
- id: netbird
|
- id: netbird
|
||||||
|
|
||||||
@@ -186,10 +166,6 @@ archives:
|
|||||||
- netbird-wasm
|
- netbird-wasm
|
||||||
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||||
format: binary
|
format: binary
|
||||||
- id: netbird-idp-migrate
|
|
||||||
builds:
|
|
||||||
- netbird-idp-migrate
|
|
||||||
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
## Contributor License Agreement
|
## Contributor License Agreement
|
||||||
|
|
||||||
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
||||||
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
|
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
|
||||||
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
||||||
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
||||||
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
||||||
|
|||||||
2
Makefile
2
Makefile
@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
|||||||
$(GOLANGCI_LINT):
|
$(GOLANGCI_LINT):
|
||||||
@echo "Installing golangci-lint..."
|
@echo "Installing golangci-lint..."
|
||||||
@mkdir -p ./bin
|
@mkdir -p ./bin
|
||||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
|
|
||||||
# Lint only changed files (fast, for pre-push)
|
# Lint only changed files (fast, for pre-push)
|
||||||
lint: $(GOLANGCI_LINT)
|
lint: $(GOLANGCI_LINT)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
@@ -16,7 +15,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -28,7 +26,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
types "github.com/netbirdio/netbird/upload-server/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -71,30 +68,7 @@ type Client struct {
|
|||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
|
||||||
stateMu sync.RWMutex
|
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
config *profilemanager.Config
|
|
||||||
cacheDir string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
|
|
||||||
c.stateMu.Lock()
|
|
||||||
defer c.stateMu.Unlock()
|
|
||||||
c.config = cfg
|
|
||||||
c.cacheDir = cacheDir
|
|
||||||
c.connectClient = cc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
|
|
||||||
c.stateMu.RLock()
|
|
||||||
defer c.stateMu.RUnlock()
|
|
||||||
return c.config, c.cacheDir, c.connectClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) getConnectClient() *internal.ConnectClient {
|
|
||||||
c.stateMu.RLock()
|
|
||||||
defer c.stateMu.RUnlock()
|
|
||||||
return c.connectClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -119,7 +93,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
cacheDir := platformFiles.CacheDir()
|
|
||||||
|
|
||||||
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -151,9 +124,8 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
c.setState(cfg, cacheDir, connectClient)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -163,7 +135,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
cacheDir := platformFiles.CacheDir()
|
|
||||||
|
|
||||||
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -186,9 +157,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
c.setState(cfg, cacheDir, connectClient)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -203,12 +173,11 @@ func (c *Client) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) RenewTun(fd int) error {
|
func (c *Client) RenewTun(fd int) error {
|
||||||
cc := c.getConnectClient()
|
if c.connectClient == nil {
|
||||||
if cc == nil {
|
|
||||||
return fmt.Errorf("engine not running")
|
return fmt.Errorf("engine not running")
|
||||||
}
|
}
|
||||||
|
|
||||||
e := cc.Engine()
|
e := c.connectClient.Engine()
|
||||||
if e == nil {
|
if e == nil {
|
||||||
return fmt.Errorf("engine not initialized")
|
return fmt.Errorf("engine not initialized")
|
||||||
}
|
}
|
||||||
@@ -216,73 +185,6 @@ func (c *Client) RenewTun(fd int) error {
|
|||||||
return e.RenewTun(fd)
|
return e.RenewTun(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
|
|
||||||
// It works both with and without a running engine.
|
|
||||||
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
|
|
||||||
cfg, cacheDir, cc := c.stateSnapshot()
|
|
||||||
|
|
||||||
// If the engine hasn't been started, load config from disk
|
|
||||||
if cfg == nil {
|
|
||||||
var err error
|
|
||||||
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
|
||||||
ConfigPath: platformFiles.ConfigurationFilePath(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("load config: %w", err)
|
|
||||||
}
|
|
||||||
cacheDir = platformFiles.CacheDir()
|
|
||||||
}
|
|
||||||
|
|
||||||
deps := debug.GeneratorDependencies{
|
|
||||||
InternalConfig: cfg,
|
|
||||||
StatusRecorder: c.recorder,
|
|
||||||
TempDir: cacheDir,
|
|
||||||
}
|
|
||||||
|
|
||||||
if cc != nil {
|
|
||||||
resp, err := cc.GetLatestSyncResponse()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("get latest sync response: %v", err)
|
|
||||||
}
|
|
||||||
deps.SyncResponse = resp
|
|
||||||
|
|
||||||
if e := cc.Engine(); e != nil {
|
|
||||||
if cm := e.GetClientMetrics(); cm != nil {
|
|
||||||
deps.ClientMetrics = cm
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bundleGenerator := debug.NewBundleGenerator(
|
|
||||||
deps,
|
|
||||||
debug.BundleConfig{
|
|
||||||
Anonymize: anonymize,
|
|
||||||
IncludeSystemInfo: true,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
path, err := bundleGenerator.Generate()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := os.Remove(path); err != nil {
|
|
||||||
log.Errorf("failed to remove debug bundle file: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("debug bundle uploaded with key %s", key)
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
// SetTraceLogLevel configure the logger to trace level
|
||||||
func (c *Client) SetTraceLogLevel() {
|
func (c *Client) SetTraceLogLevel() {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
@@ -303,7 +205,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
pi := PeerInfo{
|
pi := PeerInfo{
|
||||||
p.IP,
|
p.IP,
|
||||||
p.FQDN,
|
p.FQDN,
|
||||||
int(p.ConnStatus),
|
p.ConnStatus.String(),
|
||||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
@@ -312,13 +214,12 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Networks() *NetworkArray {
|
func (c *Client) Networks() *NetworkArray {
|
||||||
cc := c.getConnectClient()
|
if c.connectClient == nil {
|
||||||
if cc == nil {
|
|
||||||
log.Error("not connected")
|
log.Error("not connected")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := cc.Engine()
|
engine := c.connectClient.Engine()
|
||||||
if engine == nil {
|
if engine == nil {
|
||||||
log.Error("could not get engine")
|
log.Error("could not get engine")
|
||||||
return nil
|
return nil
|
||||||
@@ -399,7 +300,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||||
client := c.getConnectClient()
|
client := c.connectClient
|
||||||
if client == nil {
|
if client == nil {
|
||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,20 +2,11 @@
|
|||||||
|
|
||||||
package android
|
package android
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
|
|
||||||
// Connection status constants exported via gomobile.
|
|
||||||
const (
|
|
||||||
ConnStatusIdle = int(peer.StatusIdle)
|
|
||||||
ConnStatusConnecting = int(peer.StatusConnecting)
|
|
||||||
ConnStatusConnected = int(peer.StatusConnected)
|
|
||||||
)
|
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
type PeerInfo struct {
|
type PeerInfo struct {
|
||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus int
|
ConnStatus string // Todo replace to enum
|
||||||
Routes PeerRoutes
|
Routes PeerRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,5 +7,4 @@ package android
|
|||||||
type PlatformFiles interface {
|
type PlatformFiles interface {
|
||||||
ConfigurationFilePath() string
|
ConfigurationFilePath() string
|
||||||
StateFilePath() string
|
StateFilePath() string
|
||||||
CacheDir() string
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,434 +0,0 @@
|
|||||||
//go:build android
|
|
||||||
|
|
||||||
package android
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
gossh "golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
sshDialTimeout = 30 * time.Second
|
|
||||||
sshDetectionTimeout = 5 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
// SSHTerminalListener receives SSH session events. It is implemented in Java.
|
|
||||||
//
|
|
||||||
// All callbacks are invoked from goroutines and may run concurrently with each
|
|
||||||
// other; the implementation must be safe to call from any thread.
|
|
||||||
type SSHTerminalListener interface {
|
|
||||||
OnConnected()
|
|
||||||
OnData(data []byte)
|
|
||||||
OnClose(reason string)
|
|
||||||
OnError(message string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SSHClient is a NetBird-aware SSH client exposed to Java via gomobile.
|
|
||||||
//
|
|
||||||
// It dials through the running NetBird tunnel and runs a standard SSH session
|
|
||||||
// on top with PTY enabled. Host-key verification uses the NetBird-provided
|
|
||||||
// peer SSH host keys, identical to the desktop client.
|
|
||||||
type SSHClient struct {
|
|
||||||
nb *Client
|
|
||||||
mu sync.Mutex
|
|
||||||
listener SSHTerminalListener
|
|
||||||
urlOpener URLOpener
|
|
||||||
|
|
||||||
sshClient *gossh.Client
|
|
||||||
session *gossh.Session
|
|
||||||
stdin io.WriteCloser
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSSHClient creates a new SSH client bound to the running NetBird Client.
|
|
||||||
func NewSSHClient(c *Client) *SSHClient {
|
|
||||||
return &SSHClient{nb: c}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetListener registers the Java listener. Must be called before Connect to
|
|
||||||
// receive any events.
|
|
||||||
func (s *SSHClient) SetListener(l SSHTerminalListener) {
|
|
||||||
s.mu.Lock()
|
|
||||||
s.listener = l
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetURLOpener registers the Java URL opener used to display the device-code
|
|
||||||
// authorization page in a Custom Tabs window when the target peer requires
|
|
||||||
// JWT authentication. Must be set before Connect to be effective.
|
|
||||||
func (s *SSHClient) SetURLOpener(opener URLOpener) {
|
|
||||||
s.mu.Lock()
|
|
||||||
s.urlOpener = opener
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect dials the SSH server through the NetBird tunnel and performs the
|
|
||||||
// SSH handshake. It auto-detects the server type via SSH banner inspection
|
|
||||||
// and selects the appropriate authentication path:
|
|
||||||
//
|
|
||||||
// - NetBird-SSH server requiring JWT: launches the OAuth 2.0 device-code
|
|
||||||
// flow, opens the verification URL through the registered URLOpener, and
|
|
||||||
// uses the resulting token as the SSH password. Host-key verification
|
|
||||||
// uses the NetBird peer registry.
|
|
||||||
// - NetBird-SSH server without JWT: authenticates with the NetBird SSH
|
|
||||||
// private key. Host-key verification uses the NetBird peer registry.
|
|
||||||
// - Regular SSH server (e.g. OpenSSH): authenticates with the NetBird key
|
|
||||||
// first (so a user-installed NetBird public key works), then falls back
|
|
||||||
// to the supplied password if non-empty. Host-key verification is
|
|
||||||
// disabled (TOFU pending).
|
|
||||||
//
|
|
||||||
// The password parameter is only consulted for regular SSH servers.
|
|
||||||
func (s *SSHClient) Connect(host string, port int, user, password string) error {
|
|
||||||
cfg, _, cc := s.nb.stateSnapshot()
|
|
||||||
if cc == nil {
|
|
||||||
return errors.New("netbird client not running")
|
|
||||||
}
|
|
||||||
if cfg == nil {
|
|
||||||
return errors.New("netbird config not loaded")
|
|
||||||
}
|
|
||||||
engine := cc.Engine()
|
|
||||||
if engine == nil {
|
|
||||||
return errors.New("netbird engine not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
serverType := detectServerType(host, port)
|
|
||||||
log.Infof("SSH server type for %s:%d: %s", host, port, serverType)
|
|
||||||
|
|
||||||
authMethods, hostKeyCallback, err := s.buildAuth(cfg, engine, serverType, password)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientConfig := &gossh.ClientConfig{
|
|
||||||
User: user,
|
|
||||||
Auth: authMethods,
|
|
||||||
HostKeyCallback: hostKeyCallback,
|
|
||||||
Timeout: sshDialTimeout,
|
|
||||||
}
|
|
||||||
return s.dialAndHandshake(host, port, clientConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartSession requests a PTY and starts an interactive shell. Output from
|
|
||||||
// the session is forwarded to the listener via OnData.
|
|
||||||
func (s *SSHClient) StartSession(cols, rows int) error {
|
|
||||||
log.Debugf("SSH: starting session %dx%d", cols, rows)
|
|
||||||
s.mu.Lock()
|
|
||||||
sshClient := s.sshClient
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
if sshClient == nil {
|
|
||||||
return errors.New("ssh client not connected")
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := sshClient.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("new session: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
modes := gossh.TerminalModes{
|
|
||||||
gossh.ECHO: 1,
|
|
||||||
gossh.TTY_OP_ISPEED: 14400,
|
|
||||||
gossh.TTY_OP_OSPEED: 14400,
|
|
||||||
gossh.VINTR: 3,
|
|
||||||
gossh.VQUIT: 28,
|
|
||||||
gossh.VERASE: 127,
|
|
||||||
}
|
|
||||||
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
|
||||||
closeQuiet(session, "session after pty error")
|
|
||||||
return fmt.Errorf("request pty: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stdin, err := session.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
closeQuiet(session, "session after stdin error")
|
|
||||||
return fmt.Errorf("stdin pipe: %w", err)
|
|
||||||
}
|
|
||||||
stdout, err := session.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
closeQuiet(session, "session after stdout error")
|
|
||||||
return fmt.Errorf("stdout pipe: %w", err)
|
|
||||||
}
|
|
||||||
stderr, err := session.StderrPipe()
|
|
||||||
if err != nil {
|
|
||||||
closeQuiet(session, "session after stderr error")
|
|
||||||
return fmt.Errorf("stderr pipe: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
closeQuiet(session, "session after shell error")
|
|
||||||
return fmt.Errorf("start shell: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
s.session = session
|
|
||||||
s.stdin = stdin
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
go s.readLoop(stdout, "stdout")
|
|
||||||
go s.readLoop(stderr, "stderr")
|
|
||||||
log.Debug("SSH: session started, shell running")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write sends data to the SSH session stdin.
|
|
||||||
func (s *SSHClient) Write(data []byte) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
stdin := s.stdin
|
|
||||||
s.mu.Unlock()
|
|
||||||
if stdin == nil {
|
|
||||||
return errors.New("ssh session not started")
|
|
||||||
}
|
|
||||||
if _, err := stdin.Write(data); err != nil {
|
|
||||||
return fmt.Errorf("write stdin: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resize updates the PTY window size.
|
|
||||||
func (s *SSHClient) Resize(cols, rows int) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
session := s.session
|
|
||||||
s.mu.Unlock()
|
|
||||||
if session == nil {
|
|
||||||
return errors.New("ssh session not started")
|
|
||||||
}
|
|
||||||
return session.WindowChange(rows, cols)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close terminates the SSH session and underlying connection. Safe to call
|
|
||||||
// multiple times.
|
|
||||||
func (s *SSHClient) Close() error {
|
|
||||||
s.mu.Lock()
|
|
||||||
sshClient := s.sshClient
|
|
||||||
session := s.session
|
|
||||||
stdin := s.stdin
|
|
||||||
s.sshClient = nil
|
|
||||||
s.session = nil
|
|
||||||
s.stdin = nil
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
if stdin != nil {
|
|
||||||
if err := stdin.Close(); err != nil {
|
|
||||||
log.Debugf("ssh: stdin close: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if session != nil {
|
|
||||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
log.Debugf("ssh: session close: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var firstErr error
|
|
||||||
if sshClient != nil {
|
|
||||||
if err := sshClient.Close(); err != nil {
|
|
||||||
firstErr = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.notifyClose("closed by client")
|
|
||||||
return firstErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSHClient) buildAuth(cfg *profilemanager.Config, engine *internal.Engine,
|
|
||||||
serverType detection.ServerType, password string) ([]gossh.AuthMethod, gossh.HostKeyCallback, error) {
|
|
||||||
|
|
||||||
switch serverType {
|
|
||||||
case detection.ServerTypeNetBirdJWT:
|
|
||||||
token, err := s.requestJWTToken(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("jwt: %w", err)
|
|
||||||
}
|
|
||||||
auths := []gossh.AuthMethod{gossh.Password(token)}
|
|
||||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
|
||||||
|
|
||||||
case detection.ServerTypeNetBirdNoJWT:
|
|
||||||
if cfg.SSHKey == "" {
|
|
||||||
return nil, nil, errors.New("no NetBird SSH key available")
|
|
||||||
}
|
|
||||||
signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("parse netbird ssh key: %w", err)
|
|
||||||
}
|
|
||||||
auths := []gossh.AuthMethod{gossh.PublicKeys(signer)}
|
|
||||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
|
||||||
|
|
||||||
default: // regular SSH
|
|
||||||
var auths []gossh.AuthMethod
|
|
||||||
if cfg.SSHKey != "" {
|
|
||||||
if signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey)); err == nil {
|
|
||||||
auths = append(auths, gossh.PublicKeys(signer))
|
|
||||||
} else {
|
|
||||||
log.Debugf("ssh: parse netbird key for regular auth: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if password != "" {
|
|
||||||
pw := password
|
|
||||||
auths = append(auths, gossh.Password(pw))
|
|
||||||
auths = append(auths, gossh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
|
|
||||||
answers := make([]string, len(questions))
|
|
||||||
for i := range questions {
|
|
||||||
answers[i] = pw
|
|
||||||
}
|
|
||||||
return answers, nil
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
if len(auths) == 0 {
|
|
||||||
return nil, nil, errors.New("no auth method available: provide a password or configure NetBird SSH key")
|
|
||||||
}
|
|
||||||
return auths, gossh.InsecureIgnoreHostKey(), nil // nolint:gosec // TOFU not yet implemented
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSHClient) requestJWTToken(cfg *profilemanager.Config) (string, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
urlOpener := s.urlOpener
|
|
||||||
s.mu.Unlock()
|
|
||||||
if urlOpener == nil {
|
|
||||||
return "", errors.New("URL opener not configured for JWT auth")
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
flow, err := auth.NewOAuthFlow(ctx, cfg, false, true, profilemanager.GetLoginHint())
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("create oauth flow: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
flowInfo, err := flow.RequestAuthInfo(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("request auth info: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
|
||||||
|
|
||||||
tokenInfo, err := flow.WaitToken(ctx, flowInfo)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("wait for token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
token := tokenInfo.GetTokenToUse()
|
|
||||||
if token == "" {
|
|
||||||
return "", errors.New("empty token returned by IdP")
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSHClient) dialAndHandshake(host string, port int, clientConfig *gossh.ClientConfig) error {
|
|
||||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
|
||||||
log.Infof("SSH: connecting to %s as %s", addr, clientConfig.User)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var dialer net.Dialer
|
|
||||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("dial %s: %w", addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sshConn, chans, reqs, err := gossh.NewClientConn(conn, addr, clientConfig)
|
|
||||||
if err != nil {
|
|
||||||
if cerr := conn.Close(); cerr != nil {
|
|
||||||
log.Debugf("ssh: close after handshake error: %v", cerr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("ssh handshake: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
s.sshClient = gossh.NewClient(sshConn, chans, reqs)
|
|
||||||
listener := s.listener
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
log.Infof("SSH: connected to %s", addr)
|
|
||||||
if listener != nil {
|
|
||||||
listener.OnConnected()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSHClient) readLoop(r io.Reader, name string) {
|
|
||||||
buf := make([]byte, 4096)
|
|
||||||
for {
|
|
||||||
n, err := r.Read(buf)
|
|
||||||
if n > 0 {
|
|
||||||
s.mu.Lock()
|
|
||||||
listener := s.listener
|
|
||||||
s.mu.Unlock()
|
|
||||||
if listener != nil {
|
|
||||||
chunk := make([]byte, n)
|
|
||||||
copy(chunk, buf[:n])
|
|
||||||
listener.OnData(chunk)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, io.EOF) {
|
|
||||||
log.Debugf("ssh %s read: %v", name, err)
|
|
||||||
}
|
|
||||||
s.notifyClose(err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SSHClient) notifyClose(reason string) {
|
|
||||||
s.mu.Lock()
|
|
||||||
if s.closed {
|
|
||||||
s.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.closed = true
|
|
||||||
listener := s.listener
|
|
||||||
s.mu.Unlock()
|
|
||||||
if listener != nil {
|
|
||||||
listener.OnClose(reason)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// engineHostKeyVerifier adapts *internal.Engine to nbssh.HostKeyVerifier.
|
|
||||||
type engineHostKeyVerifier struct {
|
|
||||||
engine *internal.Engine
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *engineHostKeyVerifier) VerifySSHHostKey(peerAddress string, presented []byte) error {
|
|
||||||
storedKey, found := v.engine.GetPeerSSHKey(peerAddress)
|
|
||||||
if !found {
|
|
||||||
return nbssh.ErrPeerNotFound
|
|
||||||
}
|
|
||||||
return nbssh.VerifyHostKey(storedKey, presented, peerAddress)
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeQuiet(c io.Closer, label string) {
|
|
||||||
if c == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := c.Close(); err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
log.Debugf("ssh: close %s: %v", label, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func detectServerType(host string, port int) detection.ServerType {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), sshDetectionTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
dialer := &net.Dialer{}
|
|
||||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("ssh: server detection for %s:%d failed: %v (assuming regular SSH)", host, port, err)
|
|
||||||
return detection.ServerTypeRegular
|
|
||||||
}
|
|
||||||
return serverType
|
|
||||||
}
|
|
||||||
@@ -199,11 +199,9 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
cmd.Println("Log level set to trace.")
|
cmd.Println("Log level set to trace.")
|
||||||
}
|
}
|
||||||
|
|
||||||
needsRestoreUp := false
|
|
||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||||
} else {
|
} else {
|
||||||
needsRestoreUp = !stateWasDown
|
|
||||||
cmd.Println("netbird down")
|
cmd.Println("netbird down")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,7 +217,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||||
} else {
|
} else {
|
||||||
needsRestoreUp = false
|
|
||||||
cmd.Println("netbird up")
|
cmd.Println("netbird up")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,14 +264,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
|
|
||||||
if needsRestoreUp {
|
|
||||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
|
||||||
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
|
|
||||||
} else {
|
|
||||||
cmd.Println("netbird up (restored)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if stateWasDown {
|
if stateWasDown {
|
||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/expose"
|
"github.com/netbirdio/netbird/client/internal/expose"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
@@ -202,7 +201,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
stream, err := client.ExposeService(ctx, req)
|
stream, err := client.ExposeService(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
|
return fmt.Errorf("expose service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||||
@@ -237,7 +236,7 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
|||||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||||
event, err := stream.Recv()
|
event, err := stream.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
|
return fmt.Errorf("receive expose event: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||||
|
|||||||
@@ -75,7 +75,6 @@ var (
|
|||||||
mtu uint16
|
mtu uint16
|
||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
networksDisabled bool
|
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
|
|||||||
@@ -44,13 +44,10 @@ func init() {
|
|||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
|
||||||
|
|
||||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||||
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
||||||
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
|
|
||||||
`Use --service-env "" to clear all saved env vars. ` +
|
|
||||||
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||||
|
|
||||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
|
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
|
||||||
if err := serverInstance.Start(); err != nil {
|
if err := serverInstance.Start(); err != nil {
|
||||||
log.Fatalf("failed to start daemon: %v", err)
|
log.Fatalf("failed to start daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,10 +59,6 @@ func buildServiceArguments() []string {
|
|||||||
args = append(args, "--disable-update-settings")
|
args = append(args, "--disable-update-settings")
|
||||||
}
|
}
|
||||||
|
|
||||||
if networksDisabled {
|
|
||||||
args = append(args, "--disable-networks")
|
|
||||||
}
|
|
||||||
|
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ type serviceParams struct {
|
|||||||
LogFiles []string `json:"log_files,omitempty"`
|
LogFiles []string `json:"log_files,omitempty"`
|
||||||
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
|
||||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,12 +78,11 @@ func currentServiceParams() *serviceParams {
|
|||||||
LogFiles: logFiles,
|
LogFiles: logFiles,
|
||||||
DisableProfiles: profilesDisabled,
|
DisableProfiles: profilesDisabled,
|
||||||
DisableUpdateSettings: updateSettingsDisabled,
|
DisableUpdateSettings: updateSettingsDisabled,
|
||||||
DisableNetworks: networksDisabled,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(serviceEnvVars) > 0 {
|
if len(serviceEnvVars) > 0 {
|
||||||
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
if err == nil {
|
if err == nil && len(parsed) > 0 {
|
||||||
params.ServiceEnvVars = parsed
|
params.ServiceEnvVars = parsed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -144,46 +142,31 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
|||||||
updateSettingsDisabled = params.DisableUpdateSettings
|
updateSettingsDisabled = params.DisableUpdateSettings
|
||||||
}
|
}
|
||||||
|
|
||||||
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
|
|
||||||
networksDisabled = params.DisableNetworks
|
|
||||||
}
|
|
||||||
|
|
||||||
applyServiceEnvParams(cmd, params)
|
applyServiceEnvParams(cmd, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyServiceEnvParams merges saved service environment variables.
|
// applyServiceEnvParams merges saved service environment variables.
|
||||||
// If --service-env was explicitly set with values, explicit values win on key
|
// If --service-env was explicitly set, explicit values win on key conflict
|
||||||
// conflict but saved keys not in the explicit set are carried over.
|
// but saved keys not in the explicit set are carried over.
|
||||||
// If --service-env was explicitly set to empty, all saved env vars are cleared.
|
|
||||||
// If --service-env was not set, saved env vars are used entirely.
|
// If --service-env was not set, saved env vars are used entirely.
|
||||||
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
||||||
if !cmd.Flags().Changed("service-env") {
|
if len(params.ServiceEnvVars) == 0 {
|
||||||
if len(params.ServiceEnvVars) > 0 {
|
|
||||||
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
|
||||||
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flag was explicitly set: parse what the user provided.
|
if !cmd.Flags().Changed("service-env") {
|
||||||
|
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||||
|
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit env vars were provided: merge saved values underneath.
|
||||||
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the user passed an empty value (e.g. --service-env ""), clear all
|
|
||||||
// saved env vars rather than merging.
|
|
||||||
if len(explicit) == 0 {
|
|
||||||
serviceEnvVars = nil
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(params.ServiceEnvVars) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Merge saved values underneath explicit ones.
|
|
||||||
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
||||||
maps.Copy(merged, params.ServiceEnvVars)
|
maps.Copy(merged, params.ServiceEnvVars)
|
||||||
maps.Copy(merged, explicit) // explicit wins on conflict
|
maps.Copy(merged, explicit) // explicit wins on conflict
|
||||||
|
|||||||
@@ -25,10 +25,10 @@ func TestServiceParamsPath(t *testing.T) {
|
|||||||
t.Cleanup(func() { configs.StateDir = original })
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
|
||||||
configs.StateDir = "/var/lib/netbird"
|
configs.StateDir = "/var/lib/netbird"
|
||||||
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
|
assert.Equal(t, "/var/lib/netbird/service.json", serviceParamsPath())
|
||||||
|
|
||||||
configs.StateDir = "/custom/state"
|
configs.StateDir = "/custom/state"
|
||||||
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
|
assert.Equal(t, "/custom/state/service.json", serviceParamsPath())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSaveAndLoadServiceParams(t *testing.T) {
|
func TestSaveAndLoadServiceParams(t *testing.T) {
|
||||||
@@ -327,41 +327,6 @@ func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
|
|||||||
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
|
|
||||||
origServiceEnvVars := serviceEnvVars
|
|
||||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
|
||||||
|
|
||||||
// Simulate --service-env "" which produces [""] in the slice.
|
|
||||||
serviceEnvVars = []string{""}
|
|
||||||
|
|
||||||
cmd := &cobra.Command{}
|
|
||||||
cmd.Flags().StringSlice("service-env", nil, "")
|
|
||||||
require.NoError(t, cmd.Flags().Set("service-env", ""))
|
|
||||||
|
|
||||||
saved := &serviceParams{
|
|
||||||
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
|
|
||||||
}
|
|
||||||
|
|
||||||
applyServiceEnvParams(cmd, saved)
|
|
||||||
|
|
||||||
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
|
|
||||||
origServiceEnvVars := serviceEnvVars
|
|
||||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
|
||||||
|
|
||||||
// Simulate --service-env "" which produces [""] in the slice.
|
|
||||||
serviceEnvVars = []string{""}
|
|
||||||
|
|
||||||
params := currentServiceParams()
|
|
||||||
|
|
||||||
// After parsing, the empty string is skipped, resulting in an empty map.
|
|
||||||
// The map should still be set (not nil) so it overwrites saved values.
|
|
||||||
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
|
|
||||||
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
||||||
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
||||||
// added to serviceParams but not wired into these functions, this test fails.
|
// added to serviceParams but not wired into these functions, this test fails.
|
||||||
@@ -535,7 +500,6 @@ func fieldToGlobalVar(field string) string {
|
|||||||
"LogFiles": "logFiles",
|
"LogFiles": "logFiles",
|
||||||
"DisableProfiles": "profilesDisabled",
|
"DisableProfiles": "profilesDisabled",
|
||||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||||
"DisableNetworks": "networksDisabled",
|
|
||||||
"ServiceEnvVars": "serviceEnvVars",
|
"ServiceEnvVars": "serviceEnvVars",
|
||||||
}
|
}
|
||||||
if v, ok := m[field]; ok {
|
if v, ok := m[field]; ok {
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"syscall"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -15,22 +13,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestMain intercepts when this test binary is run as a daemon subprocess.
|
|
||||||
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
|
|
||||||
// "service run ..." arguments. Since the test binary can't handle cobra CLI
|
|
||||||
// args, it exits immediately, causing daemon -r to respawn rapidly until
|
|
||||||
// hitting the rate limit and exiting. This makes service restart unreliable.
|
|
||||||
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
|
|
||||||
sig := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
|
|
||||||
<-sig
|
|
||||||
return
|
|
||||||
}
|
|
||||||
os.Exit(m.Run())
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
serviceStartTimeout = 10 * time.Second
|
serviceStartTimeout = 10 * time.Second
|
||||||
serviceStopTimeout = 5 * time.Second
|
serviceStopTimeout = 5 * time.Second
|
||||||
@@ -97,34 +79,6 @@ func TestServiceLifecycle(t *testing.T) {
|
|||||||
logLevel = "info"
|
logLevel = "info"
|
||||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||||
|
|
||||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
|
||||||
t.Cleanup(func() {
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("cleanup: create service config: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("cleanup: create service: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the subtests already cleaned up, there's nothing to do.
|
|
||||||
if _, err := s.Status(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.Stop(); err != nil {
|
|
||||||
t.Errorf("cleanup: stop service: %v", err)
|
|
||||||
}
|
|
||||||
if err := s.Uninstall(); err != nil {
|
|
||||||
t.Errorf("cleanup: uninstall service: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
t.Run("Install", func(t *testing.T) {
|
t.Run("Install", func(t *testing.T) {
|
||||||
|
|||||||
@@ -13,8 +13,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
@@ -102,16 +100,9 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersmanager)
|
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||||
|
|
||||||
ctx := context.Background()
|
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
||||||
|
|
||||||
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
@@ -122,11 +113,12 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
Return(&types.Settings{}, nil).
|
Return(&types.Settings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -160,7 +152,7 @@ func startClientDaemon(
|
|||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
|
|
||||||
server := client.New(ctx,
|
server := client.New(ctx,
|
||||||
"", "", false, false, false)
|
"", "", false, false)
|
||||||
if err := server.Start(); err != nil {
|
if err := server.Start(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
@@ -36,34 +35,20 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
|
// on the linux system we try to user nftables or iptables
|
||||||
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
|
// in any case, because we need to allow netbird interface traffic
|
||||||
log.Info("forcing userspace firewall")
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
// for the userspace packet filtering firewall
|
||||||
}
|
|
||||||
|
|
||||||
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
|
||||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||||
|
|
||||||
// Kernel cannot fall back to anything else, need to return error
|
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return fm, err
|
return fm, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to the userspace packet filter if native is unavailable
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
|
||||||
}
|
}
|
||||||
|
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||||
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
|
||||||
// needs a device filter for DNS interception hooks. Install a minimal
|
|
||||||
// hooks-only filter that passes all traffic through to the kernel firewall.
|
|
||||||
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
|
||||||
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fm, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
@@ -175,17 +160,3 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|||||||
_, err := client.ListChains("filter")
|
_, err := client.ListChains("filter")
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func forceUserspaceFirewall() bool {
|
|
||||||
val := os.Getenv(EnvForceUserspaceFirewall)
|
|
||||||
if val == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
force, err := strconv.ParseBool(val)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return force
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,12 +7,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
|
|
||||||
// native iptables/nftables is available. This only applies when the WireGuard interface
|
|
||||||
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
|
|
||||||
// kernel netfilter rules.
|
|
||||||
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
|
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
|
|||||||
@@ -21,10 +21,6 @@ const (
|
|||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||||
|
|
||||||
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
|
|
||||||
// external DNAT from bypassing ACL rules.
|
|
||||||
mangleFwdKey = "MANGLE-FORWARD"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type aclEntries map[string][][]string
|
type aclEntries map[string][][]string
|
||||||
@@ -278,12 +274,6 @@ func (m *aclManager) cleanChains() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rule := range m.entries[mangleFwdKey] {
|
|
||||||
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
|
||||||
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
if err := m.flushIPSet(ipsetName); err != nil {
|
if err := m.flushIPSet(ipsetName); err != nil {
|
||||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
@@ -313,10 +303,6 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for chainName, rules := range m.entries {
|
for chainName, rules := range m.entries {
|
||||||
// mangle FORWARD guard rules are handled separately below
|
|
||||||
if chainName == mangleFwdKey {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||||
@@ -336,13 +322,6 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
}
|
}
|
||||||
clear(m.optionalEntries)
|
clear(m.optionalEntries)
|
||||||
|
|
||||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
|
||||||
for _, rule := range m.entries[mangleFwdKey] {
|
|
||||||
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
|
||||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -364,22 +343,6 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
|
|
||||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||||
|
|
||||||
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
|
|
||||||
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
|
|
||||||
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
|
|
||||||
// ACL mark check where it cannot be overridden.
|
|
||||||
m.appendToEntries(mangleFwdKey, []string{
|
|
||||||
"-i", m.wgIface.Name(),
|
|
||||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
|
||||||
"-j", "ACCEPT",
|
|
||||||
})
|
|
||||||
m.appendToEntries(mangleFwdKey, []string{
|
|
||||||
"-i", m.wgIface.Name(),
|
|
||||||
"-m", "conntrack", "--ctstate", "DNAT",
|
|
||||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
|
||||||
"-j", "DROP",
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ type Manager struct {
|
|||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
@@ -63,9 +64,10 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||||
state := &ShutdownState{
|
state := &ShutdownState{
|
||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
MTU: m.router.mtu,
|
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||||
|
MTU: m.router.mtu,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
stateManager.RegisterState(state)
|
stateManager.RegisterState(state)
|
||||||
@@ -201,10 +203,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic.
|
// AllowNetbird allows netbird interface traffic
|
||||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
|
||||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
if !m.wgIface.IsUserspaceBind() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
@@ -282,22 +286,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
|
||||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
|
||||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chainNameRaw = "NETBIRD-RAW"
|
chainNameRaw = "NETBIRD-RAW"
|
||||||
chainOUTPUT = "OUTPUT"
|
chainOUTPUT = "OUTPUT"
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
|||||||
panic("AddressFunc is not set")
|
panic("AddressFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||||
|
|
||||||
func TestIptablesManager(t *testing.T) {
|
func TestIptablesManager(t *testing.T) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ const (
|
|||||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||||
chainRTPRE = "NETBIRD-RT-PRE"
|
chainRTPRE = "NETBIRD-RT-PRE"
|
||||||
chainRTRDR = "NETBIRD-RT-RDR"
|
chainRTRDR = "NETBIRD-RT-RDR"
|
||||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
|
||||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
@@ -44,7 +43,6 @@ const (
|
|||||||
jumpManglePre = "jump-mangle-pre"
|
jumpManglePre = "jump-mangle-pre"
|
||||||
jumpNatPre = "jump-nat-pre"
|
jumpNatPre = "jump-nat-pre"
|
||||||
jumpNatPost = "jump-nat-post"
|
jumpNatPost = "jump-nat-post"
|
||||||
jumpNatOutput = "jump-nat-output"
|
|
||||||
jumpMSSClamp = "jump-mss-clamp"
|
jumpMSSClamp = "jump-mss-clamp"
|
||||||
markManglePre = "mark-mangle-pre"
|
markManglePre = "mark-mangle-pre"
|
||||||
markManglePost = "mark-mangle-post"
|
markManglePost = "mark-mangle-post"
|
||||||
@@ -389,14 +387,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("flushing routing related tables")
|
log.Debug("flushing routing related tables")
|
||||||
|
|
||||||
// Remove jump rules from built-in chains before deleting custom chains,
|
|
||||||
// otherwise the chain deletion fails with "device or resource busy".
|
|
||||||
jumpRule := []string{"-j", chainNATOutput}
|
|
||||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
|
||||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chainInfo := range []struct {
|
for _, chainInfo := range []struct {
|
||||||
chain string
|
chain string
|
||||||
table string
|
table string
|
||||||
@@ -406,7 +396,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTNAT, tableNat},
|
||||||
{chainRTRDR, tableNat},
|
{chainRTRDR, tableNat},
|
||||||
{chainNATOutput, tableNat},
|
|
||||||
{chainRTMSSCLAMP, tableMangle},
|
{chainRTMSSCLAMP, tableMangle},
|
||||||
} {
|
} {
|
||||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||||
@@ -981,81 +970,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
|
||||||
func (r *router) ensureNATOutputChain() error {
|
|
||||||
if _, exists := r.rules[jumpNatOutput]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
|
||||||
}
|
|
||||||
if !chainExists {
|
|
||||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
|
||||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
jumpRule := []string{"-j", chainNATOutput}
|
|
||||||
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
|
|
||||||
if !chainExists {
|
|
||||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
|
||||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules[jumpNatOutput] = jumpRule
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
|
||||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.ensureNATOutputChain(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dnatRule := []string{
|
|
||||||
"-p", strings.ToLower(string(protocol)),
|
|
||||||
"--dport", strconv.Itoa(int(sourcePort)),
|
|
||||||
"-d", localAddr.String(),
|
|
||||||
"-j", "DNAT",
|
|
||||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
|
||||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules[ruleID] = dnatRule
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
|
||||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
|
||||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
|
||||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyPort(flag string, port *firewall.Port) []string {
|
func applyPort(flag string, port *firewall.Port) []string {
|
||||||
if port == nil {
|
if port == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -9,9 +9,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress wgaddr.Address `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
MTU uint16 `json:"mtu"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
|
MTU uint16 `json:"mtu"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Name() string {
|
func (i *InterfaceState) Name() string {
|
||||||
@@ -22,6 +23,10 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
|||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||||
|
return i.UserspaceBind
|
||||||
|
}
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
|
|||||||
@@ -169,14 +169,6 @@ type Manager interface {
|
|||||||
// RemoveInboundDNAT removes inbound DNAT rule
|
// RemoveInboundDNAT removes inbound DNAT rule
|
||||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
|
||||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
|
||||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
|
||||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
|
||||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ func getTableName() string {
|
|||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
@@ -105,9 +106,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
// cleanup using Close() without needing to store specific rules.
|
// cleanup using Close() without needing to store specific rules.
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
MTU: m.router.mtu,
|
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||||
|
MTU: m.router.mtu,
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
@@ -203,10 +205,12 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return m.router.RemoveNatRule(pair)
|
return m.router.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic.
|
// AllowNetbird allows netbird interface traffic
|
||||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
|
||||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
if !m.wgIface.IsUserspaceBind() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -342,22 +346,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
|
||||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
|
||||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chainNameRawOutput = "netbird-raw-out"
|
chainNameRawOutput = "netbird-raw-out"
|
||||||
chainNameRawPrerouting = "netbird-raw-pre"
|
chainNameRawPrerouting = "netbird-raw-pre"
|
||||||
|
|||||||
@@ -52,6 +52,8 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
|||||||
panic("AddressFunc is not set")
|
panic("AddressFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||||
|
|
||||||
func TestNftablesManager(t *testing.T) {
|
func TestNftablesManager(t *testing.T) {
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ const (
|
|||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||||
chainNameNATOutput = "netbird-nat-output"
|
|
||||||
chainNameForward = "FORWARD"
|
chainNameForward = "FORWARD"
|
||||||
chainNameMangleForward = "netbird-mangle-forward"
|
chainNameMangleForward = "netbird-mangle-forward"
|
||||||
|
|
||||||
@@ -1854,130 +1853,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
|
||||||
func (r *router) ensureNATOutputChain() error {
|
|
||||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameNATOutput,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: nftables.ChainHookOutput,
|
|
||||||
Priority: nftables.ChainPriorityNATDest,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
delete(r.chains, chainNameNATOutput)
|
|
||||||
return fmt.Errorf("create NAT output chain: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
|
||||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.ensureNATOutputChain(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoNum, err := protoToInt(protocol)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("convert protocol to number: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{protoNum},
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 2,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
|
||||||
|
|
||||||
exprs = append(exprs,
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: localAddr.AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 2,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
|
||||||
},
|
|
||||||
&expr.NAT{
|
|
||||||
Type: expr.NATTypeDestNAT,
|
|
||||||
Family: uint32(nftables.TableFamilyIPv4),
|
|
||||||
RegAddrMin: 1,
|
|
||||||
RegProtoMin: 2,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
dnatRule := &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameNATOutput],
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(ruleID),
|
|
||||||
}
|
|
||||||
r.conn.AddRule(dnatRule)
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.rules[ruleID] = dnatRule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
|
||||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
|
||||||
|
|
||||||
rule, exists := r.rules[ruleID]
|
|
||||||
if !exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.Handle == 0 {
|
|
||||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||||
func (r *router) applyNetwork(
|
func (r *router) applyNetwork(
|
||||||
network firewall.Network,
|
network firewall.Network,
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress wgaddr.Address `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
MTU uint16 `json:"mtu"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
|
MTU uint16 `json:"mtu"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Name() string {
|
func (i *InterfaceState) Name() string {
|
||||||
@@ -21,6 +22,10 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
|||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||||
|
return i.UserspaceBind
|
||||||
|
}
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PacketHook stores a registered hook for a specific IP:port.
|
|
||||||
type PacketHook struct {
|
|
||||||
IP netip.Addr
|
|
||||||
Port uint16
|
|
||||||
Fn func([]byte) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// HookMatches checks if a packet's destination matches the hook and invokes it.
|
|
||||||
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
|
||||||
if h == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if h.IP == dstIP && h.Port == dport {
|
|
||||||
return h.Fn(packetData)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetHook atomically stores a hook, handling nil removal.
|
|
||||||
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
|
||||||
if hook == nil {
|
|
||||||
ptr.Store(nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ptr.Store(&PacketHook{
|
|
||||||
IP: ip,
|
|
||||||
Port: dPort,
|
|
||||||
Fn: hook,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -140,10 +140,6 @@ type Manager struct {
|
|||||||
mtu uint16
|
mtu uint16
|
||||||
mssClampValue uint16
|
mssClampValue uint16
|
||||||
mssClampEnabled bool
|
mssClampEnabled bool
|
||||||
|
|
||||||
// Only one hook per protocol is supported. Outbound direction only.
|
|
||||||
udpHookOut atomic.Pointer[common.PacketHook]
|
|
||||||
tcpHookOut atomic.Pointer[common.PacketHook]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -598,8 +594,6 @@ func (m *Manager) resetState() {
|
|||||||
maps.Clear(m.incomingRules)
|
maps.Clear(m.incomingRules)
|
||||||
maps.Clear(m.routeRulesMap)
|
maps.Clear(m.routeRulesMap)
|
||||||
m.routeRules = m.routeRules[:0]
|
m.routeRules = m.routeRules[:0]
|
||||||
m.udpHookOut.Store(nil)
|
|
||||||
m.tcpHookOut.Store(nil)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
@@ -719,9 +713,6 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||||
if m.mssClampEnabled {
|
if m.mssClampEnabled {
|
||||||
@@ -904,12 +895,39 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
d.dnatOrigPort = 0
|
d.dnatOrigPort = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
m.mutex.RLock()
|
||||||
}
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
// Check specific destination IP first
|
||||||
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check IPv4 unspecified address
|
||||||
|
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check IPv6 unspecified address
|
||||||
|
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterInbound implements filtering logic for incoming packets.
|
// filterInbound implements filtering logic for incoming packets.
|
||||||
@@ -1260,6 +1278,12 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
|||||||
return rule.mgmtId, rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
|
// if rule has UDP hook (and if we are here we match this rule)
|
||||||
|
// we ignore rule.drop and call this hook
|
||||||
|
if rule.udpHook != nil {
|
||||||
|
return rule.mgmtId, rule.udpHook(packetData), true
|
||||||
|
}
|
||||||
|
|
||||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
return rule.mgmtId, rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
@@ -1318,14 +1342,65 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
|||||||
return sourceMatched
|
return sourceMatched
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
//
|
||||||
common.SetHook(&m.udpHookOut, ip, dPort, hook)
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
|
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||||
|
r := PeerRule{
|
||||||
|
id: uuid.New().String(),
|
||||||
|
ip: ip,
|
||||||
|
protoLayer: layers.LayerTypeUDP,
|
||||||
|
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||||
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
|
udpHook: hook,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip.Is4() {
|
||||||
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
if in {
|
||||||
|
// Incoming UDP hooks are stored in allow rules map
|
||||||
|
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||||
|
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||||
|
}
|
||||||
|
m.incomingRules[r.ip][r.id] = r
|
||||||
|
} else {
|
||||||
|
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||||
|
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||||
|
}
|
||||||
|
m.outgoingRules[r.ip][r.id] = r
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
return r.id
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
// RemovePacketHook removes packet hook by given ID
|
||||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||||
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
// Check incoming hooks (stored in allow rules)
|
||||||
|
for _, arr := range m.incomingRules {
|
||||||
|
for _, r := range arr {
|
||||||
|
if r.id == hookID {
|
||||||
|
delete(arr, r.id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check outgoing hooks
|
||||||
|
for _, arr := range m.outgoingRules {
|
||||||
|
for _, r := range arr {
|
||||||
|
if r.id == hookID {
|
||||||
|
delete(arr, r.id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("hook with given id not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLogLevel sets the log level for the firewall manager
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
@@ -187,52 +186,81 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetUDPPacketHook(t *testing.T) {
|
func TestAddUDPPacketHook(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
tests := []struct {
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
name string
|
||||||
}, false, flowLogger, nbiface.DefaultMTU)
|
in bool
|
||||||
require.NoError(t, err)
|
expDir fw.RuleDirection
|
||||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
ip netip.Addr
|
||||||
|
dPort uint16
|
||||||
|
hook func([]byte) bool
|
||||||
|
expectedID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Test Outgoing UDP Packet Hook",
|
||||||
|
in: false,
|
||||||
|
expDir: fw.RuleDirectionOUT,
|
||||||
|
ip: netip.MustParseAddr("10.168.0.1"),
|
||||||
|
dPort: 8000,
|
||||||
|
hook: func([]byte) bool { return true },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test Incoming UDP Packet Hook",
|
||||||
|
in: true,
|
||||||
|
expDir: fw.RuleDirectionIN,
|
||||||
|
ip: netip.MustParseAddr("::1"),
|
||||||
|
dPort: 9000,
|
||||||
|
hook: func([]byte) bool { return false },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
var called bool
|
for _, tt := range tests {
|
||||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
called = true
|
manager, err := Create(&IFaceMock{
|
||||||
return true
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
})
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
h := manager.udpHookOut.Load()
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
require.NotNil(t, h)
|
|
||||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
|
||||||
assert.Equal(t, uint16(8000), h.Port)
|
|
||||||
assert.True(t, h.Fn(nil))
|
|
||||||
assert.True(t, called)
|
|
||||||
|
|
||||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
var addedRule PeerRule
|
||||||
assert.Nil(t, manager.udpHookOut.Load())
|
if tt.in {
|
||||||
}
|
// Incoming UDP hooks are stored in allow rules map
|
||||||
|
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||||
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, rule := range manager.incomingRules[tt.ip] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(manager.outgoingRules[tt.ip]) != 1 {
|
||||||
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSetTCPPacketHook(t *testing.T) {
|
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||||
manager, err := Create(&IFaceMock{
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
return
|
||||||
}, false, flowLogger, nbiface.DefaultMTU)
|
}
|
||||||
require.NoError(t, err)
|
if tt.dPort != addedRule.dPort.Values[0] {
|
||||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||||
|
return
|
||||||
var called bool
|
}
|
||||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
|
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||||
called = true
|
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||||
return true
|
return
|
||||||
})
|
}
|
||||||
|
if addedRule.udpHook == nil {
|
||||||
h := manager.tcpHookOut.Load()
|
t.Errorf("expected udpHook to be set")
|
||||||
require.NotNil(t, h)
|
return
|
||||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
}
|
||||||
assert.Equal(t, uint16(53), h.Port)
|
})
|
||||||
assert.True(t, h.Fn(nil))
|
}
|
||||||
assert.True(t, called)
|
|
||||||
|
|
||||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
|
||||||
assert.Nil(t, manager.tcpHookOut.Load())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||||
@@ -502,12 +530,39 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
|
// Add a UDP packet hook
|
||||||
|
hookFunc := func(data []byte) bool { return true }
|
||||||
|
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||||
|
|
||||||
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
|
found := false
|
||||||
|
for _, arr := range manager.outgoingRules {
|
||||||
|
for _, rule := range arr {
|
||||||
|
if rule.id == hookID {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
|
if !found {
|
||||||
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
|
t.Fatalf("The hook was not added properly.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now remove the packet hook
|
||||||
|
err = manager.RemovePacketHook(hookID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to remove hook: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||||
|
for _, arr := range manager.outgoingRules {
|
||||||
|
for _, rule := range arr {
|
||||||
|
if rule.id == hookID {
|
||||||
|
t.Fatalf("The hook was not removed properly.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
@@ -537,7 +592,8 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hookCalled := false
|
hookCalled := false
|
||||||
manager.SetUDPPacketHook(
|
hookID := manager.AddUDPPacketHook(
|
||||||
|
false,
|
||||||
netip.MustParseAddr("100.10.0.100"),
|
netip.MustParseAddr("100.10.0.100"),
|
||||||
53,
|
53,
|
||||||
func([]byte) bool {
|
func([]byte) bool {
|
||||||
@@ -545,6 +601,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
require.NotEmpty(t, hookID)
|
||||||
|
|
||||||
// Create test UDP packet
|
// Create test UDP packet
|
||||||
ipv4 := &layers.IPv4{
|
ipv4 := &layers.IPv4{
|
||||||
|
|||||||
@@ -1,90 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"net/netip"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4HeaderMinLen = 20
|
|
||||||
ipv4ProtoOffset = 9
|
|
||||||
ipv4FlagsOffset = 6
|
|
||||||
ipv4DstOffset = 16
|
|
||||||
ipProtoUDP = 17
|
|
||||||
ipProtoTCP = 6
|
|
||||||
ipv4FragOffMask = 0x1fff
|
|
||||||
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
|
|
||||||
dstPortOffset = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
|
|
||||||
// It is installed on the WireGuard interface when the userspace bind is active
|
|
||||||
// but a full firewall filter (Manager) is not needed because a native kernel
|
|
||||||
// firewall (nftables/iptables) handles packet filtering.
|
|
||||||
type HooksFilter struct {
|
|
||||||
udpHook atomic.Pointer[common.PacketHook]
|
|
||||||
tcpHook atomic.Pointer[common.PacketHook]
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ device.PacketFilter = (*HooksFilter)(nil)
|
|
||||||
|
|
||||||
// FilterOutbound checks outbound packets for DNS hook matches.
|
|
||||||
// Only IPv4 packets matching the registered hook IP:port are intercepted.
|
|
||||||
// IPv6 and non-IP packets pass through unconditionally.
|
|
||||||
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
|
|
||||||
if len(packetData) < ipv4HeaderMinLen {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only process IPv4 packets, let everything else pass through.
|
|
||||||
if packetData[0]>>4 != 4 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
ihl := int(packetData[0]&0x0f) * 4
|
|
||||||
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip non-first fragments: they don't carry L4 headers.
|
|
||||||
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
|
|
||||||
if flagsAndOffset&ipv4FragOffMask != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
proto := packetData[ipv4ProtoOffset]
|
|
||||||
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
|
|
||||||
|
|
||||||
switch proto {
|
|
||||||
case ipProtoUDP:
|
|
||||||
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
|
|
||||||
case ipProtoTCP:
|
|
||||||
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FilterInbound allows all inbound packets (native firewall handles filtering).
|
|
||||||
func (f *HooksFilter) FilterInbound([]byte, int) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUDPPacketHook registers the UDP packet hook.
|
|
||||||
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
|
||||||
common.SetHook(&f.udpHook, ip, dPort, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTCPPacketHook registers the TCP packet hook.
|
|
||||||
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
|
||||||
common.SetHook(&f.tcpHook, ip, dPort, hook)
|
|
||||||
}
|
|
||||||
@@ -144,8 +144,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get interfaces: %v", err)
|
log.Warnf("failed to get interfaces: %v", err)
|
||||||
} else {
|
} else {
|
||||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
|
||||||
// case where an interface comes up between refreshes.
|
|
||||||
for _, intf := range interfaces {
|
for _, intf := range interfaces {
|
||||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -421,7 +421,6 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
|
||||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
var layerType gopacket.LayerType
|
var layerType gopacket.LayerType
|
||||||
switch protocol {
|
switch protocol {
|
||||||
@@ -467,22 +466,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT delegates to the native firewall if available.
|
|
||||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
|
||||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
if !m.portDNATEnabled.Load() {
|
if !m.portDNATEnabled.Load() {
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ type PeerRule struct {
|
|||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
sPort *firewall.Port
|
sPort *firewall.Port
|
||||||
dPort *firewall.Port
|
dPort *firewall.Port
|
||||||
drop bool
|
drop bool
|
||||||
|
|
||||||
|
udpHook func([]byte) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
|
|||||||
@@ -399,17 +399,21 @@ func TestTracePacket(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "UDPTraffic_WithHook",
|
name: "UDPTraffic_WithHook",
|
||||||
setup: func(m *Manager) {
|
setup: func(m *Manager) {
|
||||||
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
|
hookFunc := func([]byte) bool {
|
||||||
return true // drop (intercepted by hook)
|
return true
|
||||||
})
|
}
|
||||||
|
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
StageOutbound1to1NAT,
|
StageInboundPortDNAT,
|
||||||
StageOutboundPortReverse,
|
StageInbound1to1NAT,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
StageCompleted,
|
StageCompleted,
|
||||||
},
|
},
|
||||||
expectedAllow: false,
|
expectedAllow: false,
|
||||||
|
|||||||
@@ -15,17 +15,14 @@ type PacketFilter interface {
|
|||||||
// FilterInbound filter incoming packets from external sources to host
|
// FilterInbound filter incoming packets from external sources to host
|
||||||
FilterInbound(packetData []byte, size int) bool
|
FilterInbound(packetData []byte, size int) bool
|
||||||
|
|
||||||
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
// Hook function returns true if the packet should be dropped.
|
//
|
||||||
// Only one UDP hook is supported; calling again replaces the previous hook.
|
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||||
// Pass nil hook to remove.
|
// Hook function receives raw network packet data as argument.
|
||||||
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||||
|
|
||||||
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
|
// RemovePacketHook removes hook by ID
|
||||||
// Hook function returns true if the packet should be dropped.
|
RemovePacketHook(hookID string) error
|
||||||
// Only one TCP hook is supported; calling again replaces the previous hook.
|
|
||||||
// Pass nil hook to remove.
|
|
||||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteredDevice to override Read or Write of packets
|
// FilteredDevice to override Read or Write of packets
|
||||||
|
|||||||
@@ -217,6 +217,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
|||||||
// Close closes the tunnel interface
|
// Close closes the tunnel interface
|
||||||
func (w *WGIface) Close() error {
|
func (w *WGIface) Close() error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
@@ -224,15 +225,7 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release w.mu before calling w.tun.Close(): the underlying
|
if err := w.tun.Close(); err != nil {
|
||||||
// wireguard-go device.Close() waits for its send/receive goroutines
|
|
||||||
// to drain. Some of those goroutines re-enter WGIface methods that
|
|
||||||
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
|
|
||||||
// holding the mutex here would deadlock the shutdown path.
|
|
||||||
tun := w.tun
|
|
||||||
w.mu.Unlock()
|
|
||||||
|
|
||||||
if err := tun.Close(); err != nil {
|
|
||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iface
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
// fakeTunDevice implements WGTunDevice and lets the test control when
|
|
||||||
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
|
|
||||||
// until its goroutines drain. Some of those goroutines (e.g. the packet
|
|
||||||
// filter DNS hook in client/internal/dns) call back into WGIface, so if
|
|
||||||
// WGIface.Close() held w.mu across tun.Close() the shutdown would
|
|
||||||
// deadlock.
|
|
||||||
type fakeTunDevice struct {
|
|
||||||
closeStarted chan struct{}
|
|
||||||
unblockClose chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
|
|
||||||
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
|
|
||||||
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
|
|
||||||
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
|
|
||||||
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
|
|
||||||
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
|
|
||||||
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
|
|
||||||
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
|
|
||||||
|
|
||||||
func (f *fakeTunDevice) Close() error {
|
|
||||||
close(f.closeStarted)
|
|
||||||
<-f.unblockClose
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeProxyFactory struct{}
|
|
||||||
|
|
||||||
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
|
|
||||||
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
|
|
||||||
func (fakeProxyFactory) Free() error { return nil }
|
|
||||||
|
|
||||||
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
|
|
||||||
// that surfaces as a macOS test-timeout in
|
|
||||||
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
|
|
||||||
// while waiting for the wireguard-go device goroutines to finish, and
|
|
||||||
// one of those goroutines (the DNS filter hook) calls back into
|
|
||||||
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
|
|
||||||
// the lock before tun.Close() returns control.
|
|
||||||
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
|
|
||||||
tun := &fakeTunDevice{
|
|
||||||
closeStarted: make(chan struct{}),
|
|
||||||
unblockClose: make(chan struct{}),
|
|
||||||
}
|
|
||||||
w := &WGIface{
|
|
||||||
tun: tun,
|
|
||||||
wgProxyFactory: fakeProxyFactory{},
|
|
||||||
}
|
|
||||||
|
|
||||||
closeDone := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
closeDone <- w.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-tun.closeStarted:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
close(tun.unblockClose)
|
|
||||||
t.Fatal("tun.Close() was never invoked")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate the WireGuard read goroutine calling back into WGIface
|
|
||||||
// via the packet filter's DNS hook. If Close() still held w.mu
|
|
||||||
// during tun.Close(), this would block until the test timeout.
|
|
||||||
getDeviceDone := make(chan struct{})
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
_ = w.GetDevice()
|
|
||||||
close(getDeviceDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-getDeviceDone:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
close(tun.unblockClose)
|
|
||||||
wg.Wait()
|
|
||||||
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
|
|
||||||
}
|
|
||||||
|
|
||||||
close(tun.unblockClose)
|
|
||||||
select {
|
|
||||||
case <-closeDone:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -34,28 +34,18 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
|||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUDPPacketHook mocks base method.
|
// AddUDPPacketHook mocks base method.
|
||||||
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
|
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
|
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||||
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
|
||||||
|
|
||||||
// SetTCPPacketHook mocks base method.
|
|
||||||
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterInbound mocks base method.
|
// FilterInbound mocks base method.
|
||||||
@@ -85,3 +75,17 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemovePacketHook mocks base method.
|
||||||
|
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePacketHook indicates an expected call of RemovePacketHook.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||||
|
}
|
||||||
|
|||||||
87
client/iface/mocks/iface/mocks/filter.go
Normal file
87
client/iface/mocks/iface/mocks/filter.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
|
||||||
|
|
||||||
|
// Package mocks is a generated GoMock package.
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
net "net"
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockPacketFilter is a mock of PacketFilter interface.
|
||||||
|
type MockPacketFilter struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockPacketFilterMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
|
||||||
|
type MockPacketFilterMockRecorder struct {
|
||||||
|
mock *MockPacketFilter
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockPacketFilter creates a new mock instance.
|
||||||
|
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
|
||||||
|
mock := &MockPacketFilter{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockPacketFilterMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddUDPPacketHook mocks base method.
|
||||||
|
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterInbound mocks base method.
|
||||||
|
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterInbound indicates an expected call of FilterInbound.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterOutbound mocks base method.
|
||||||
|
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNetwork mocks base method.
|
||||||
|
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNetwork indicates an expected call of SetNetwork.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||||
|
}
|
||||||
@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if u.address.Network.Contains(a) {
|
if u.address.Network.Contains(a) {
|
||||||
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
|||||||
u.addrCache.Store(addr.String(), isRouted)
|
u.addrCache.Store(addr.String(), isRouted)
|
||||||
if isRouted {
|
if isRouted {
|
||||||
// Extra log, as the error only shows up with ICE logging enabled
|
// Extra log, as the error only shows up with ICE logging enabled
|
||||||
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,9 +19,6 @@ import (
|
|||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
func TestDefaultManager(t *testing.T) {
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
|
||||||
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
{
|
{
|
||||||
@@ -138,7 +135,6 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
func TestDefaultManagerStateless(t *testing.T) {
|
func TestDefaultManagerStateless(t *testing.T) {
|
||||||
// stateless currently only in userspace, so we have to disable kernel
|
// stateless currently only in userspace, so we have to disable kernel
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
|
||||||
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||||
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
@@ -198,7 +194,6 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
// This tests the full ACL manager -> uspfilter integration.
|
// This tests the full ACL manager -> uspfilter integration.
|
||||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
|
||||||
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
@@ -263,7 +258,6 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
|||||||
// up when they're removed from the network map in a subsequent update.
|
// up when they're removed from the network map in a subsequent update.
|
||||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
@@ -345,7 +339,6 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
|||||||
// one added without leaking.
|
// one added without leaking.
|
||||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
|||||||
var needsLogin bool
|
var needsLogin bool
|
||||||
|
|
||||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
if isLoginNeeded(err) {
|
if isLoginNeeded(err) {
|
||||||
needsLogin = true
|
needsLogin = true
|
||||||
return nil
|
return nil
|
||||||
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
|||||||
var isAuthError bool
|
var isAuthError bool
|
||||||
|
|
||||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
if isRegistrationNeeded(err) {
|
if serverKey != nil && isRegistrationNeeded(err) {
|
||||||
log.Debugf("peer registration required")
|
log.Debugf("peer registration required")
|
||||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -201,7 +201,13 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
|||||||
|
|
||||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||||
protoFlow, err := client.GetPKCEAuthorizationFlow()
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||||
@@ -215,7 +221,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
|||||||
config := &PKCEAuthProviderConfig{
|
config := &PKCEAuthProviderConfig{
|
||||||
Audience: protoConfig.GetAudience(),
|
Audience: protoConfig.GetAudience(),
|
||||||
ClientID: protoConfig.GetClientID(),
|
ClientID: protoConfig.GetClientID(),
|
||||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||||
Scope: protoConfig.GetScope(),
|
Scope: protoConfig.GetScope(),
|
||||||
@@ -240,7 +246,13 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
|||||||
|
|
||||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||||
protoFlow, err := client.GetDeviceAuthorizationFlow()
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||||
@@ -254,7 +266,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
|||||||
config := &DeviceAuthProviderConfig{
|
config := &DeviceAuthProviderConfig{
|
||||||
Audience: protoConfig.GetAudience(),
|
Audience: protoConfig.GetAudience(),
|
||||||
ClientID: protoConfig.GetClientID(),
|
ClientID: protoConfig.GetClientID(),
|
||||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
Domain: protoConfig.Domain,
|
Domain: protoConfig.Domain,
|
||||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||||
@@ -280,16 +292,28 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// doMgmLogin performs the actual login operation with the management service
|
// doMgmLogin performs the actual login operation with the management service
|
||||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
|
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
a.setSystemInfoFlags(sysInfo)
|
a.setSystemInfoFlags(sysInfo)
|
||||||
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
|
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||||
return err
|
return serverKey, loginResp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||||
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
validSetupKey, err := uuid.Parse(setupKey)
|
validSetupKey, err := uuid.Parse(setupKey)
|
||||||
if err != nil && jwtToken == "" {
|
if err != nil && jwtToken == "" {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||||
@@ -298,7 +322,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
|
|||||||
log.Debugf("sending peer registration request to Management Service")
|
log.Debugf("sending peer registration request to Management Service")
|
||||||
info := system.GetInfo(ctx)
|
info := system.GetInfo(ctx)
|
||||||
a.setSystemInfoFlags(info)
|
a.setSystemInfoFlags(info)
|
||||||
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed registering peer %v", err)
|
log.Errorf("failed registering peer %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -44,10 +44,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// androidRunOverride is set on Android to inject mobile dependencies
|
|
||||||
// when using embed.Client (which calls Run() with empty MobileDependency).
|
|
||||||
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
|
||||||
|
|
||||||
type ConnectClient struct {
|
type ConnectClient struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
@@ -80,9 +76,6 @@ func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
|||||||
|
|
||||||
// Run with main logic.
|
// Run with main logic.
|
||||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||||
if androidRunOverride != nil {
|
|
||||||
return androidRunOverride(c, runningChan, logPath)
|
|
||||||
}
|
|
||||||
return c.run(MobileDependency{}, runningChan, logPath)
|
return c.run(MobileDependency{}, runningChan, logPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +87,6 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
dnsAddresses []netip.AddrPort,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
stateFilePath string,
|
stateFilePath string,
|
||||||
cacheDir string,
|
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
@@ -104,7 +96,6 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
TempDir: cacheDir,
|
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil, "")
|
||||||
}
|
}
|
||||||
@@ -113,7 +104,6 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
fileDescriptor int32,
|
fileDescriptor int32,
|
||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsManager dns.IosDnsManager,
|
dnsManager dns.IosDnsManager,
|
||||||
dnsAddresses []netip.AddrPort,
|
|
||||||
stateFilePath string,
|
stateFilePath string,
|
||||||
) error {
|
) error {
|
||||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||||
@@ -123,7 +113,6 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
FileDescriptor: fileDescriptor,
|
FileDescriptor: fileDescriptor,
|
||||||
NetworkChangeListener: networkChangeListener,
|
NetworkChangeListener: networkChangeListener,
|
||||||
DnsManager: dnsManager,
|
DnsManager: dnsManager,
|
||||||
HostDNSAddresses: dnsAddresses,
|
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil, "")
|
||||||
@@ -340,7 +329,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
engineConfig.TempDir = mobileDependency.TempDir
|
|
||||||
|
|
||||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||||
c.statusRecorder.SetRelayMgr(relayManager)
|
c.statusRecorder.SetRelayMgr(relayManager)
|
||||||
@@ -622,6 +610,12 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
|||||||
|
|
||||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
sysInfo.SetFlags(
|
sysInfo.SetFlags(
|
||||||
config.RosenpassEnabled,
|
config.RosenpassEnabled,
|
||||||
@@ -640,7 +634,12 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.EnableSSHRemotePortForwarding,
|
config.EnableSSHRemotePortForwarding,
|
||||||
config.DisableSSHAuth,
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return loginResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
//go:build android
|
|
||||||
|
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
)
|
|
||||||
|
|
||||||
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
|
|
||||||
// It returns an empty interface list, which means ICE P2P candidates won't be
|
|
||||||
// discovered — connections will fall back to relay. Applications that need P2P
|
|
||||||
// should provide a real implementation via runOnAndroidEmbed that uses
|
|
||||||
// Android's ConnectivityManager to enumerate network interfaces.
|
|
||||||
type noopIFaceDiscover struct{}
|
|
||||||
|
|
||||||
func (noopIFaceDiscover) IFaces() (string, error) {
|
|
||||||
// Return empty JSON array — no local interfaces advertised for ICE.
|
|
||||||
// This is intentional: without Android's ConnectivityManager, we cannot
|
|
||||||
// reliably enumerate interfaces (netlink is restricted on Android 11+).
|
|
||||||
// Relay connections still work; only P2P hole-punching is disabled.
|
|
||||||
return "[]", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// noopNetworkChangeListener is a stub for embed.Client on Android.
|
|
||||||
// Network change events are ignored since the embed client manages its own
|
|
||||||
// reconnection logic via the engine's built-in retry mechanism.
|
|
||||||
type noopNetworkChangeListener struct{}
|
|
||||||
|
|
||||||
func (noopNetworkChangeListener) OnNetworkChanged(string) {
|
|
||||||
// No-op: embed.Client relies on the engine's internal reconnection
|
|
||||||
// logic rather than OS-level network change notifications.
|
|
||||||
}
|
|
||||||
|
|
||||||
func (noopNetworkChangeListener) SetInterfaceIP(string) {
|
|
||||||
// No-op: in netstack mode, the overlay IP is managed by the userspace
|
|
||||||
// network stack, not by OS-level interface configuration.
|
|
||||||
}
|
|
||||||
|
|
||||||
// noopDnsReadyListener is a stub for embed.Client on Android.
|
|
||||||
// DNS readiness notifications are not needed in netstack/embed mode
|
|
||||||
// since system DNS is disabled and DNS resolution happens externally.
|
|
||||||
type noopDnsReadyListener struct{}
|
|
||||||
|
|
||||||
func (noopDnsReadyListener) OnReady() {
|
|
||||||
// No-op: embed.Client does not need DNS readiness notifications.
|
|
||||||
// System DNS is disabled in netstack mode.
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
|
|
||||||
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
|
||||||
var _ dns.ReadyListener = noopDnsReadyListener{}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// Wire up the default override so embed.Client.Start() works on Android
|
|
||||||
// with netstack mode. Provides complete no-op stubs for all mobile
|
|
||||||
// dependencies so the engine's existing Android code paths work unchanged.
|
|
||||||
// Applications that need P2P ICE or real DNS should replace this by
|
|
||||||
// setting androidRunOverride before calling Start().
|
|
||||||
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
|
||||||
return c.runOnAndroidEmbed(
|
|
||||||
noopIFaceDiscover{},
|
|
||||||
noopNetworkChangeListener{},
|
|
||||||
[]netip.AddrPort{},
|
|
||||||
noopDnsReadyListener{},
|
|
||||||
runningChan,
|
|
||||||
logPath,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
//go:build android
|
|
||||||
|
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
)
|
|
||||||
|
|
||||||
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
|
||||||
// so embed.Client.Start() can detect when the engine is ready.
|
|
||||||
// It provides complete MobileDependency so the engine's existing
|
|
||||||
// Android code paths work unchanged.
|
|
||||||
func (c *ConnectClient) runOnAndroidEmbed(
|
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
|
||||||
networkChangeListener listener.NetworkChangeListener,
|
|
||||||
dnsAddresses []netip.AddrPort,
|
|
||||||
dnsReadyListener dns.ReadyListener,
|
|
||||||
runningChan chan struct{},
|
|
||||||
logPath string,
|
|
||||||
) error {
|
|
||||||
mobileDependency := MobileDependency{
|
|
||||||
IFaceDiscover: iFaceDiscover,
|
|
||||||
NetworkChangeListener: networkChangeListener,
|
|
||||||
HostDNSAddresses: dnsAddresses,
|
|
||||||
DnsReadyListener: dnsReadyListener,
|
|
||||||
}
|
|
||||||
return c.run(mobileDependency, runningChan, logPath)
|
|
||||||
}
|
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -24,12 +25,12 @@ import (
|
|||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/configs"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const readmeContent = `Netbird debug bundle
|
const readmeContent = `Netbird debug bundle
|
||||||
@@ -51,7 +52,6 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
|
|||||||
config.txt: Anonymized configuration information of the NetBird client.
|
config.txt: Anonymized configuration information of the NetBird client.
|
||||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||||
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
|
|
||||||
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
|
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
|
||||||
mutex.prof: Mutex profiling information.
|
mutex.prof: Mutex profiling information.
|
||||||
goroutine.prof: Goroutine profiling information.
|
goroutine.prof: Goroutine profiling information.
|
||||||
@@ -232,7 +232,6 @@ type BundleGenerator struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
syncResponse *mgmProto.SyncResponse
|
syncResponse *mgmProto.SyncResponse
|
||||||
logPath string
|
logPath string
|
||||||
tempDir string
|
|
||||||
cpuProfile []byte
|
cpuProfile []byte
|
||||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
clientMetrics MetricsExporter
|
clientMetrics MetricsExporter
|
||||||
@@ -255,7 +254,6 @@ type GeneratorDependencies struct {
|
|||||||
StatusRecorder *peer.Status
|
StatusRecorder *peer.Status
|
||||||
SyncResponse *mgmProto.SyncResponse
|
SyncResponse *mgmProto.SyncResponse
|
||||||
LogPath string
|
LogPath string
|
||||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
|
||||||
CPUProfile []byte
|
CPUProfile []byte
|
||||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
ClientMetrics MetricsExporter
|
ClientMetrics MetricsExporter
|
||||||
@@ -275,7 +273,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
statusRecorder: deps.StatusRecorder,
|
statusRecorder: deps.StatusRecorder,
|
||||||
syncResponse: deps.SyncResponse,
|
syncResponse: deps.SyncResponse,
|
||||||
logPath: deps.LogPath,
|
logPath: deps.LogPath,
|
||||||
tempDir: deps.TempDir,
|
|
||||||
cpuProfile: deps.CPUProfile,
|
cpuProfile: deps.CPUProfile,
|
||||||
refreshStatus: deps.RefreshStatus,
|
refreshStatus: deps.RefreshStatus,
|
||||||
clientMetrics: deps.ClientMetrics,
|
clientMetrics: deps.ClientMetrics,
|
||||||
@@ -288,7 +285,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
|
|
||||||
// Generate creates a debug bundle and returns the location.
|
// Generate creates a debug bundle and returns the location.
|
||||||
func (g *BundleGenerator) Generate() (resp string, err error) {
|
func (g *BundleGenerator) Generate() (resp string, err error) {
|
||||||
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
|
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("create zip file: %w", err)
|
return "", fmt.Errorf("create zip file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -362,10 +359,6 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addServiceParams(); err != nil {
|
|
||||||
log.Errorf("failed to add service params to debug bundle: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.addMetrics(); err != nil {
|
if err := g.addMetrics(); err != nil {
|
||||||
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
@@ -374,8 +367,15 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add wg show output: %v", err)
|
log.Errorf("failed to add wg show output: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addPlatformLog(); err != nil {
|
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||||
log.Errorf("failed to add logs to debug bundle: %v", err)
|
if err := g.addLogfile(); err != nil {
|
||||||
|
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||||
|
if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
log.Errorf("failed to add systemd logs as fallback: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
log.Errorf("failed to add systemd logs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addUpdateLogs(); err != nil {
|
if err := g.addUpdateLogs(); err != nil {
|
||||||
@@ -488,90 +488,6 @@ func (g *BundleGenerator) addConfig() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
serviceParamsFile = "service.json"
|
|
||||||
serviceParamsBundle = "service_params.json"
|
|
||||||
maskedValue = "***"
|
|
||||||
envVarPrefix = "NB_"
|
|
||||||
jsonKeyManagementURL = "management_url"
|
|
||||||
jsonKeyServiceEnv = "service_env_vars"
|
|
||||||
)
|
|
||||||
|
|
||||||
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
|
|
||||||
|
|
||||||
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
|
|
||||||
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
|
|
||||||
func (g *BundleGenerator) addServiceParams() error {
|
|
||||||
path := filepath.Join(configs.StateDir, serviceParamsFile)
|
|
||||||
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("read service params: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var params map[string]any
|
|
||||||
if err := json.Unmarshal(data, ¶ms); err != nil {
|
|
||||||
return fmt.Errorf("parse service params: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if g.anonymize {
|
|
||||||
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
|
|
||||||
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
g.sanitizeServiceEnvVars(params)
|
|
||||||
|
|
||||||
sanitizedData, err := json.MarshalIndent(params, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal sanitized service params: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
|
|
||||||
return fmt.Errorf("add service params to zip: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
|
|
||||||
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
|
|
||||||
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
|
|
||||||
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
|
|
||||||
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sanitized := make(map[string]any, len(envVars))
|
|
||||||
for k, v := range envVars {
|
|
||||||
val, _ := v.(string)
|
|
||||||
switch {
|
|
||||||
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
|
|
||||||
sanitized[k] = maskedValue
|
|
||||||
case g.anonymize:
|
|
||||||
sanitized[k] = g.anonymizer.AnonymizeString(val)
|
|
||||||
default:
|
|
||||||
sanitized[k] = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
params[jsonKeyServiceEnv] = sanitized
|
|
||||||
}
|
|
||||||
|
|
||||||
// isSensitiveEnvVar returns true for env var names that may contain secrets.
|
|
||||||
func isSensitiveEnvVar(key string) bool {
|
|
||||||
lower := strings.ToLower(key)
|
|
||||||
for _, s := range sensitiveEnvSubstrings {
|
|
||||||
if strings.Contains(lower, s) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,41 +0,0 @@
|
|||||||
//go:build android
|
|
||||||
|
|
||||||
package debug
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os/exec"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (g *BundleGenerator) addPlatformLog() error {
|
|
||||||
cmd := exec.Command("/system/bin/logcat", "-d")
|
|
||||||
stdout, err := cmd.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("logcat stdout pipe: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
return fmt.Errorf("start logcat: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var logReader io.Reader = stdout
|
|
||||||
if g.anonymize {
|
|
||||||
var pw *io.PipeWriter
|
|
||||||
logReader, pw = io.Pipe()
|
|
||||||
go anonymizeLog(stdout, pw, g.anonymizer)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
|
|
||||||
return fmt.Errorf("add logcat to zip: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cmd.Wait(); err != nil {
|
|
||||||
return fmt.Errorf("wait logcat: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("added logcat output to debug bundle")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package debug
|
|
||||||
|
|
||||||
import (
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (g *BundleGenerator) addPlatformLog() error {
|
|
||||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
|
||||||
if err := g.addLogfile(); err != nil {
|
|
||||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
|
||||||
if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,12 +1,8 @@
|
|||||||
package debug
|
package debug
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/zip"
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -14,7 +10,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/configs"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -425,226 +420,6 @@ func TestAnonymizeNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsSensitiveEnvVar(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
key string
|
|
||||||
sensitive bool
|
|
||||||
}{
|
|
||||||
{"NB_SETUP_KEY", true},
|
|
||||||
{"NB_API_TOKEN", true},
|
|
||||||
{"NB_CLIENT_SECRET", true},
|
|
||||||
{"NB_PASSWORD", true},
|
|
||||||
{"NB_CREDENTIAL", true},
|
|
||||||
{"NB_LOG_LEVEL", false},
|
|
||||||
{"NB_MANAGEMENT_URL", false},
|
|
||||||
{"NB_HOSTNAME", false},
|
|
||||||
{"HOME", false},
|
|
||||||
{"PATH", false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.key, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSanitizeServiceEnvVars(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
anonymize bool
|
|
||||||
input map[string]any
|
|
||||||
check func(t *testing.T, params map[string]any)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no env vars key",
|
|
||||||
anonymize: false,
|
|
||||||
input: map[string]any{"management_url": "https://mgmt.example.com"},
|
|
||||||
check: func(t *testing.T, params map[string]any) {
|
|
||||||
t.Helper()
|
|
||||||
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
|
|
||||||
_, ok := params[jsonKeyServiceEnv]
|
|
||||||
assert.False(t, ok, "service_env_vars should not be added")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-NB vars are masked",
|
|
||||||
anonymize: false,
|
|
||||||
input: map[string]any{
|
|
||||||
jsonKeyServiceEnv: map[string]any{
|
|
||||||
"HOME": "/root",
|
|
||||||
"PATH": "/usr/bin",
|
|
||||||
"NB_LOG_LEVEL": "debug",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, params map[string]any) {
|
|
||||||
t.Helper()
|
|
||||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
|
||||||
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
|
|
||||||
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
|
|
||||||
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sensitive NB vars are masked",
|
|
||||||
anonymize: false,
|
|
||||||
input: map[string]any{
|
|
||||||
jsonKeyServiceEnv: map[string]any{
|
|
||||||
"NB_SETUP_KEY": "abc123",
|
|
||||||
"NB_API_TOKEN": "tok_xyz",
|
|
||||||
"NB_LOG_LEVEL": "info",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, params map[string]any) {
|
|
||||||
t.Helper()
|
|
||||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
|
||||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
|
|
||||||
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
|
|
||||||
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "safe NB vars anonymized when anonymize is true",
|
|
||||||
anonymize: true,
|
|
||||||
input: map[string]any{
|
|
||||||
jsonKeyServiceEnv: map[string]any{
|
|
||||||
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
|
|
||||||
"NB_LOG_LEVEL": "debug",
|
|
||||||
"NB_SETUP_KEY": "secret",
|
|
||||||
"SOME_OTHER": "val",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
check: func(t *testing.T, params map[string]any) {
|
|
||||||
t.Helper()
|
|
||||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
|
||||||
// Safe NB_ values should be anonymized (not the original, not masked)
|
|
||||||
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
|
|
||||||
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
|
|
||||||
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
|
|
||||||
|
|
||||||
logVal := env["NB_LOG_LEVEL"].(string)
|
|
||||||
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
|
|
||||||
|
|
||||||
// Sensitive and non-NB_ still masked
|
|
||||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
|
|
||||||
assert.Equal(t, maskedValue, env["SOME_OTHER"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
|
||||||
g := &BundleGenerator{
|
|
||||||
anonymize: tt.anonymize,
|
|
||||||
anonymizer: anonymizer,
|
|
||||||
}
|
|
||||||
g.sanitizeServiceEnvVars(tt.input)
|
|
||||||
tt.check(t, tt.input)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddServiceParams(t *testing.T) {
|
|
||||||
t.Run("missing service.json returns nil", func(t *testing.T) {
|
|
||||||
g := &BundleGenerator{
|
|
||||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
|
||||||
}
|
|
||||||
|
|
||||||
origStateDir := configs.StateDir
|
|
||||||
configs.StateDir = t.TempDir()
|
|
||||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
|
||||||
|
|
||||||
err := g.addServiceParams()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
|
|
||||||
dir := t.TempDir()
|
|
||||||
origStateDir := configs.StateDir
|
|
||||||
configs.StateDir = dir
|
|
||||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
|
||||||
|
|
||||||
input := map[string]any{
|
|
||||||
jsonKeyManagementURL: "https://api.example.com:443",
|
|
||||||
jsonKeyServiceEnv: map[string]any{
|
|
||||||
"NB_LOG_LEVEL": "trace",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(input)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
zw := zip.NewWriter(&buf)
|
|
||||||
|
|
||||||
g := &BundleGenerator{
|
|
||||||
anonymize: true,
|
|
||||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
|
||||||
archive: zw,
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, g.addServiceParams())
|
|
||||||
require.NoError(t, zw.Close())
|
|
||||||
|
|
||||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, zr.File, 1)
|
|
||||||
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
|
|
||||||
|
|
||||||
rc, err := zr.File[0].Open()
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer rc.Close()
|
|
||||||
|
|
||||||
var result map[string]any
|
|
||||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
|
||||||
|
|
||||||
mgmt := result[jsonKeyManagementURL].(string)
|
|
||||||
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
|
|
||||||
assert.NotEmpty(t, mgmt)
|
|
||||||
|
|
||||||
env := result[jsonKeyServiceEnv].(map[string]any)
|
|
||||||
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
|
|
||||||
dir := t.TempDir()
|
|
||||||
origStateDir := configs.StateDir
|
|
||||||
configs.StateDir = dir
|
|
||||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
|
||||||
|
|
||||||
input := map[string]any{
|
|
||||||
jsonKeyManagementURL: "https://api.example.com:443",
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(input)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
zw := zip.NewWriter(&buf)
|
|
||||||
|
|
||||||
g := &BundleGenerator{
|
|
||||||
anonymize: false,
|
|
||||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
|
||||||
archive: zw,
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, g.addServiceParams())
|
|
||||||
require.NoError(t, zw.Close())
|
|
||||||
|
|
||||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
rc, err := zr.File[0].Open()
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer rc.Close()
|
|
||||||
|
|
||||||
var result map[string]any
|
|
||||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
|
||||||
|
|
||||||
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to check if IP is in CGNAT range
|
// Helper function to check if IP is in CGNAT range
|
||||||
func isInCGNATRange(ip net.IP) bool {
|
func isInCGNATRange(ip net.IP) bool {
|
||||||
cgnat := net.IPNet{
|
cgnat := net.IPNet{
|
||||||
|
|||||||
@@ -73,9 +73,6 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
w.response = m
|
w.response = m
|
||||||
if m.MsgHdr.Truncated {
|
|
||||||
w.SetMeta("truncated", "true")
|
|
||||||
}
|
|
||||||
return w.ResponseWriter.WriteMsg(m)
|
return w.ResponseWriter.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,14 +195,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
requestID := resutil.GenerateRequestID()
|
requestID := resutil.GenerateRequestID()
|
||||||
fields := log.Fields{
|
logger := log.WithFields(log.Fields{
|
||||||
"request_id": requestID,
|
"request_id": requestID,
|
||||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||||
}
|
})
|
||||||
if addr := w.RemoteAddr(); addr != nil {
|
|
||||||
fields["client"] = addr.String()
|
|
||||||
}
|
|
||||||
logger := log.WithFields(fields)
|
|
||||||
|
|
||||||
question := r.Question[0]
|
question := r.Question[0]
|
||||||
qname := strings.ToLower(question.Name)
|
qname := strings.ToLower(question.Name)
|
||||||
@@ -268,9 +261,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
|||||||
meta += " " + k + "=" + v
|
meta += " " + k + "=" + v
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
|
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||||
cw.response.Len(), meta, time.Since(startTime))
|
meta, time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
|
|||||||
@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestLocalResolver_Stop tests cleanup on GracefullyStop
|
// TestLocalResolver_Stop tests cleanup on Stop
|
||||||
func TestLocalResolver_Stop(t *testing.T) {
|
func TestLocalResolver_Stop(t *testing.T) {
|
||||||
t.Run("GracefullyStop clears all state", func(t *testing.T) {
|
t.Run("Stop clears all state", func(t *testing.T) {
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
resolver.Update([]nbdns.CustomZone{{
|
resolver.Update([]nbdns.CustomZone{{
|
||||||
Domain: "example.com.",
|
Domain: "example.com.",
|
||||||
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
|||||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
|
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
resolver.Update([]nbdns.CustomZone{{
|
resolver.Update([]nbdns.CustomZone{{
|
||||||
Domain: "example.com.",
|
Domain: "example.com.",
|
||||||
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
|||||||
resolver.Stop()
|
resolver.Stop()
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
|
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
|
|
||||||
lookupStarted := make(chan struct{})
|
lookupStarted := make(chan struct{})
|
||||||
|
|||||||
@@ -90,11 +90,6 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
|||||||
// Mock implementation - no-op
|
// Mock implementation - no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFirewall mock implementation of SetFirewall from Server interface
|
|
||||||
func (m *MockServer) SetFirewall(Firewall) {
|
|
||||||
// Mock implementation - no-op
|
|
||||||
}
|
|
||||||
|
|
||||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||||
func (m *MockServer) BeginBatch() {
|
func (m *MockServer) BeginBatch() {
|
||||||
// Mock implementation - no-op
|
// Mock implementation - no-op
|
||||||
|
|||||||
@@ -104,23 +104,3 @@ func (r *responseWriter) TsigTimersOnly(bool) {
|
|||||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||||
func (r *responseWriter) Hijack() {
|
func (r *responseWriter) Hijack() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
|
|
||||||
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
|
|
||||||
var srcIP net.IP
|
|
||||||
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
|
|
||||||
srcIP = ipv4.(*layers.IPv4).SrcIP
|
|
||||||
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
|
|
||||||
srcIP = ipv6.(*layers.IPv6).SrcIP
|
|
||||||
}
|
|
||||||
|
|
||||||
var srcPort int
|
|
||||||
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
|
|
||||||
srcPort = int(udp.(*layers.UDP).SrcPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
if srcIP == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &net.UDPAddr{IP: srcIP, Port: srcPort}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -58,7 +58,6 @@ type Server interface {
|
|||||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||||
PopulateManagementDomain(mgmtURL *url.URL) error
|
PopulateManagementDomain(mgmtURL *url.URL) error
|
||||||
SetRouteChecker(func(netip.Addr) bool)
|
SetRouteChecker(func(netip.Addr) bool)
|
||||||
SetFirewall(Firewall)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type nsGroupsByDomain struct {
|
type nsGroupsByDomain struct {
|
||||||
@@ -152,7 +151,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
|
|||||||
if config.WgInterface.IsUserspaceBind() {
|
if config.WgInterface.IsUserspaceBind() {
|
||||||
dnsService = NewServiceViaMemory(config.WgInterface)
|
dnsService = NewServiceViaMemory(config.WgInterface)
|
||||||
} else {
|
} else {
|
||||||
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
|
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||||
@@ -187,16 +186,11 @@ func NewDefaultServerIos(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
wgInterface WGIface,
|
wgInterface WGIface,
|
||||||
iosDnsManager IosDnsManager,
|
iosDnsManager IosDnsManager,
|
||||||
hostsDnsList []netip.AddrPort,
|
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
disableSys bool,
|
disableSys bool,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
|
|
||||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||||
ds.iosDnsManager = iosDnsManager
|
ds.iosDnsManager = iosDnsManager
|
||||||
ds.hostsDNSHolder.set(hostsDnsList)
|
|
||||||
ds.permanent = true
|
|
||||||
ds.addHostRootZone()
|
|
||||||
return ds
|
return ds
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,17 +374,6 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
|||||||
return s.service.RuntimeIP()
|
return s.service.RuntimeIP()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFirewall sets the firewall used for DNS port DNAT rules.
|
|
||||||
// This must be called before Initialize when using the listener-based service,
|
|
||||||
// because the firewall is typically not available at construction time.
|
|
||||||
func (s *DefaultServer) SetFirewall(fw Firewall) {
|
|
||||||
if svc, ok := s.service.(*serviceViaListener); ok {
|
|
||||||
svc.listenerFlagLock.Lock()
|
|
||||||
svc.firewall = fw
|
|
||||||
svc.listenerFlagLock.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the server
|
// Stop stops the server
|
||||||
func (s *DefaultServer) Stop() {
|
func (s *DefaultServer) Stop() {
|
||||||
s.probeMu.Lock()
|
s.probeMu.Lock()
|
||||||
@@ -412,12 +395,8 @@ func (s *DefaultServer) Stop() {
|
|||||||
maps.Clear(s.extraDomains)
|
maps.Clear(s.extraDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) disableDNS() (retErr error) {
|
func (s *DefaultServer) disableDNS() error {
|
||||||
defer func() {
|
defer s.service.Stop()
|
||||||
if err := s.service.Stop(); err != nil {
|
|
||||||
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if s.isUsingNoopHostManager() {
|
if s.isUsingNoopHostManager() {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
|
|
||||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
t.Errorf("set packet filter: %v", err)
|
t.Errorf("set packet filter: %v", err)
|
||||||
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
|
|||||||
type mockService struct{}
|
type mockService struct{}
|
||||||
|
|
||||||
func (m *mockService) Listen() error { return nil }
|
func (m *mockService) Listen() error { return nil }
|
||||||
func (m *mockService) Stop() error { return nil }
|
func (m *mockService) Stop() {}
|
||||||
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
|
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
|
||||||
func (m *mockService) RuntimePort() int { return 53 }
|
func (m *mockService) RuntimePort() int { return 53 }
|
||||||
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||||
|
|||||||
@@ -4,25 +4,15 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DefaultPort = 53
|
DefaultPort = 53
|
||||||
)
|
)
|
||||||
|
|
||||||
// Firewall provides DNAT capabilities for DNS port redirection.
|
|
||||||
// This is used when the DNS server cannot bind port 53 directly
|
|
||||||
// and needs firewall rules to redirect traffic.
|
|
||||||
type Firewall interface {
|
|
||||||
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
|
||||||
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type service interface {
|
type service interface {
|
||||||
Listen() error
|
Listen() error
|
||||||
Stop() error
|
Stop()
|
||||||
RegisterMux(domain string, handler dns.Handler)
|
RegisterMux(domain string, handler dns.Handler)
|
||||||
DeregisterMux(key string)
|
DeregisterMux(key string)
|
||||||
RuntimePort() int
|
RuntimePort() int
|
||||||
|
|||||||
@@ -10,13 +10,9 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
)
|
)
|
||||||
@@ -35,33 +31,25 @@ type serviceViaListener struct {
|
|||||||
dnsMux *dns.ServeMux
|
dnsMux *dns.ServeMux
|
||||||
customAddr *netip.AddrPort
|
customAddr *netip.AddrPort
|
||||||
server *dns.Server
|
server *dns.Server
|
||||||
tcpServer *dns.Server
|
|
||||||
listenIP netip.Addr
|
listenIP netip.Addr
|
||||||
listenPort uint16
|
listenPort uint16
|
||||||
listenerIsRunning bool
|
listenerIsRunning bool
|
||||||
listenerFlagLock sync.Mutex
|
listenerFlagLock sync.Mutex
|
||||||
ebpfService ebpfMgr.Manager
|
ebpfService ebpfMgr.Manager
|
||||||
firewall Firewall
|
|
||||||
tcpDNATConfigured bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
|
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
s := &serviceViaListener{
|
s := &serviceViaListener{
|
||||||
wgInterface: wgIface,
|
wgInterface: wgIface,
|
||||||
dnsMux: mux,
|
dnsMux: mux,
|
||||||
customAddr: customAddr,
|
customAddr: customAddr,
|
||||||
firewall: fw,
|
|
||||||
server: &dns.Server{
|
server: &dns.Server{
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
UDPSize: 65535,
|
UDPSize: 65535,
|
||||||
},
|
},
|
||||||
tcpServer: &dns.Server{
|
|
||||||
Net: "tcp",
|
|
||||||
Handler: mux,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
@@ -82,86 +70,43 @@ func (s *serviceViaListener) Listen() error {
|
|||||||
return fmt.Errorf("eval listen address: %w", err)
|
return fmt.Errorf("eval listen address: %w", err)
|
||||||
}
|
}
|
||||||
s.listenIP = s.listenIP.Unmap()
|
s.listenIP = s.listenIP.Unmap()
|
||||||
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||||
s.server.Addr = addr
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
s.tcpServer.Addr = addr
|
|
||||||
|
|
||||||
log.Debugf("starting dns on %s (UDP + TCP)", addr)
|
|
||||||
s.listenerIsRunning = true
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.server.ListenAndServe(); err != nil {
|
s.setListenerStatus(true)
|
||||||
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
|
defer s.setListenerStatus(false)
|
||||||
}
|
|
||||||
|
|
||||||
s.listenerFlagLock.Lock()
|
err := s.server.ListenAndServe()
|
||||||
unexpected := s.listenerIsRunning
|
if err != nil {
|
||||||
s.listenerIsRunning = false
|
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||||
s.listenerFlagLock.Unlock()
|
|
||||||
|
|
||||||
if unexpected {
|
|
||||||
if err := s.tcpServer.Shutdown(); err != nil {
|
|
||||||
log.Debugf("failed to shutdown DNS TCP server: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := s.tcpServer.ListenAndServe(); err != nil {
|
|
||||||
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
|
|
||||||
// a DNAT rule because eBPF only handles UDP.
|
|
||||||
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
|
|
||||||
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
|
||||||
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
|
|
||||||
} else {
|
|
||||||
s.tcpDNATConfigured = true
|
|
||||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) Stop() error {
|
func (s *serviceViaListener) Stop() {
|
||||||
s.listenerFlagLock.Lock()
|
s.listenerFlagLock.Lock()
|
||||||
defer s.listenerFlagLock.Unlock()
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
if !s.listenerIsRunning {
|
if !s.listenerIsRunning {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
s.listenerIsRunning = false
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
var merr *multierror.Error
|
err := s.server.ShutdownContext(ctx)
|
||||||
|
if err != nil {
|
||||||
if err := s.server.ShutdownContext(ctx); err != nil {
|
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.tcpDNATConfigured && s.firewall != nil {
|
|
||||||
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
|
||||||
}
|
|
||||||
s.tcpDNATConfigured = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.ebpfService != nil {
|
if s.ebpfService != nil {
|
||||||
if err := s.ebpfService.FreeDNSFwd(); err != nil {
|
err = s.ebpfService.FreeDNSFwd()
|
||||||
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
|
if err != nil {
|
||||||
|
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
@@ -188,6 +133,12 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
|||||||
return s.listenIP
|
return s.listenIP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
s.listenerIsRunning = running
|
||||||
|
}
|
||||||
|
|
||||||
// evalListenAddress figure out the listen address for the DNS server
|
// evalListenAddress figure out the listen address for the DNS server
|
||||||
// first check the 53 port availability on WG interface or lo, if not success
|
// first check the 53 port availability on WG interface or lo, if not success
|
||||||
@@ -236,28 +187,18 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||||
addrPort := netip.AddrPortFrom(ip, uint16(port))
|
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
|
||||||
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
udpAddr := net.UDPAddrFromAddrPort(addrPort)
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
udpLn, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
|
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if err := udpLn.Close(); err != nil {
|
|
||||||
log.Debugf("close UDP probe listener: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
|
err = probeListener.Close()
|
||||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
|
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
if err := tcpLn.Close(); err != nil {
|
|
||||||
log.Debugf("close TCP probe listener: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,86 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
|
|
||||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Answer = append(m.Answer, &dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("192.0.2.1"),
|
|
||||||
})
|
|
||||||
if err := w.WriteMsg(m); err != nil {
|
|
||||||
t.Logf("write msg: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create a service using a custom address to avoid needing root
|
|
||||||
svc := newServiceViaListener(nil, nil, nil)
|
|
||||||
svc.dnsMux.Handle(".", handler)
|
|
||||||
|
|
||||||
// Bind both transports up front to avoid TOCTOU races.
|
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
|
|
||||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Skip("cannot bind to 127.0.0.153, skipping")
|
|
||||||
}
|
|
||||||
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
|
|
||||||
|
|
||||||
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
|
|
||||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
|
||||||
if err != nil {
|
|
||||||
udpConn.Close()
|
|
||||||
t.Skip("cannot bind TCP on same port, skipping")
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := fmt.Sprintf("%s:%d", customIP, port)
|
|
||||||
svc.server.PacketConn = udpConn
|
|
||||||
svc.tcpServer.Listener = tcpLn
|
|
||||||
svc.listenIP = customIP
|
|
||||||
svc.listenPort = port
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := svc.server.ActivateAndServe(); err != nil {
|
|
||||||
t.Logf("udp server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
if err := svc.tcpServer.ActivateAndServe(); err != nil {
|
|
||||||
t.Logf("tcp server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
svc.listenerIsRunning = true
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, svc.Stop())
|
|
||||||
}()
|
|
||||||
|
|
||||||
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
// Test UDP query
|
|
||||||
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
|
|
||||||
udpResp, _, err := udpClient.Exchange(q, addr)
|
|
||||||
require.NoError(t, err, "UDP query should succeed")
|
|
||||||
require.NotNil(t, udpResp)
|
|
||||||
require.NotEmpty(t, udpResp.Answer)
|
|
||||||
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
|
|
||||||
|
|
||||||
// Test TCP query
|
|
||||||
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
|
|
||||||
tcpResp, _, err := tcpClient.Exchange(q, addr)
|
|
||||||
require.NoError(t, err, "TCP query should succeed")
|
|
||||||
require.NotNil(t, tcpResp)
|
|
||||||
require.NotEmpty(t, tcpResp.Answer)
|
|
||||||
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -11,7 +10,6 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,8 +18,7 @@ type ServiceViaMemory struct {
|
|||||||
dnsMux *dns.ServeMux
|
dnsMux *dns.ServeMux
|
||||||
runtimeIP netip.Addr
|
runtimeIP netip.Addr
|
||||||
runtimePort int
|
runtimePort int
|
||||||
tcpDNS *tcpDNSServer
|
udpFilterHookID string
|
||||||
tcpHookSet bool
|
|
||||||
listenerIsRunning bool
|
listenerIsRunning bool
|
||||||
listenerFlagLock sync.Mutex
|
listenerFlagLock sync.Mutex
|
||||||
}
|
}
|
||||||
@@ -31,13 +28,14 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("get last ip from network: %v", err)
|
log.Errorf("get last ip from network: %v", err)
|
||||||
}
|
}
|
||||||
|
s := &ServiceViaMemory{
|
||||||
return &ServiceViaMemory{
|
|
||||||
wgInterface: wgIface,
|
wgInterface: wgIface,
|
||||||
dnsMux: dns.NewServeMux(),
|
dnsMux: dns.NewServeMux(),
|
||||||
|
|
||||||
runtimeIP: lastIP,
|
runtimeIP: lastIP,
|
||||||
runtimePort: DefaultPort,
|
runtimePort: DefaultPort,
|
||||||
}
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceViaMemory) Listen() error {
|
func (s *ServiceViaMemory) Listen() error {
|
||||||
@@ -48,8 +46,10 @@ func (s *ServiceViaMemory) Listen() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.filterDNSTraffic(); err != nil {
|
var err error
|
||||||
return fmt.Errorf("filter dns traffic: %w", err)
|
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("filter dns traffice: %w", err)
|
||||||
}
|
}
|
||||||
s.listenerIsRunning = true
|
s.listenerIsRunning = true
|
||||||
|
|
||||||
@@ -57,29 +57,19 @@ func (s *ServiceViaMemory) Listen() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceViaMemory) Stop() error {
|
func (s *ServiceViaMemory) Stop() {
|
||||||
s.listenerFlagLock.Lock()
|
s.listenerFlagLock.Lock()
|
||||||
defer s.listenerFlagLock.Unlock()
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
if !s.listenerIsRunning {
|
if !s.listenerIsRunning {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
filter := s.wgInterface.GetFilter()
|
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||||
if filter != nil {
|
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
|
||||||
if s.tcpHookSet {
|
|
||||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.tcpDNS != nil {
|
|
||||||
s.tcpDNS.Stop()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.listenerIsRunning = false
|
s.listenerIsRunning = false
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
@@ -98,18 +88,10 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
|
|||||||
return s.runtimeIP
|
return s.runtimeIP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceViaMemory) filterDNSTraffic() error {
|
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||||
filter := s.wgInterface.GetFilter()
|
filter := s.wgInterface.GetFilter()
|
||||||
if filter == nil {
|
if filter == nil {
|
||||||
return errors.New("DNS filter not initialized")
|
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||||
}
|
|
||||||
|
|
||||||
// Create TCP DNS server lazily here since the device may not exist at construction time.
|
|
||||||
if s.tcpDNS == nil {
|
|
||||||
if dev := s.wgInterface.GetDevice(); dev != nil {
|
|
||||||
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
|
|
||||||
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
firstLayerDecoder := layers.LayerTypeIPv4
|
firstLayerDecoder := layers.LayerTypeIPv4
|
||||||
@@ -118,16 +100,12 @@ func (s *ServiceViaMemory) filterDNSTraffic() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hook := func(packetData []byte) bool {
|
hook := func(packetData []byte) bool {
|
||||||
|
// Decode the packet
|
||||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||||
|
|
||||||
|
// Get the UDP layer
|
||||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||||
if udpLayer == nil {
|
udp := udpLayer.(*layers.UDP)
|
||||||
return true
|
|
||||||
}
|
|
||||||
udp, ok := udpLayer.(*layers.UDP)
|
|
||||||
if !ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
msg := new(dns.Msg)
|
||||||
if err := msg.Unpack(udp.Payload); err != nil {
|
if err := msg.Unpack(udp.Payload); err != nil {
|
||||||
@@ -135,30 +113,13 @@ func (s *ServiceViaMemory) filterDNSTraffic() error {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
dev := s.wgInterface.GetDevice()
|
writer := responseWriter{
|
||||||
if dev == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := &responseWriter{
|
|
||||||
remote: remoteAddrFromPacket(packet),
|
|
||||||
packet: packet,
|
packet: packet,
|
||||||
device: dev.Device,
|
device: s.wgInterface.GetDevice().Device,
|
||||||
}
|
}
|
||||||
go s.dnsMux.ServeDNS(writer, msg)
|
go s.dnsMux.ServeDNS(&writer, msg)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
|
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
|
||||||
|
|
||||||
if s.tcpDNS != nil {
|
|
||||||
tcpHook := func(packetData []byte) bool {
|
|
||||||
s.tcpDNS.InjectPacket(packetData)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
|
|
||||||
s.tcpHookSet = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,444 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
dnsTCPReceiveWindow = 8192
|
|
||||||
dnsTCPMaxInFlight = 16
|
|
||||||
dnsTCPIdleTimeout = 30 * time.Second
|
|
||||||
dnsTCPReadTimeout = 5 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
|
|
||||||
// It is started lazily when a truncated DNS response is detected and shuts down
|
|
||||||
// after a period of inactivity to conserve resources.
|
|
||||||
type tcpDNSServer struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
s *stack.Stack
|
|
||||||
ep *dnsEndpoint
|
|
||||||
mux *dns.ServeMux
|
|
||||||
tunDev tun.Device
|
|
||||||
ip netip.Addr
|
|
||||||
port uint16
|
|
||||||
mtu uint16
|
|
||||||
|
|
||||||
running bool
|
|
||||||
closed bool
|
|
||||||
timerID uint64
|
|
||||||
timer *time.Timer
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
|
|
||||||
return &tcpDNSServer{
|
|
||||||
mux: mux,
|
|
||||||
tunDev: tunDev,
|
|
||||||
ip: ip,
|
|
||||||
port: port,
|
|
||||||
mtu: mtu,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InjectPacket ensures the stack is running and delivers a raw IP packet into
|
|
||||||
// the gvisor stack for TCP processing. Combining both operations under a single
|
|
||||||
// lock prevents a race where the idle timer could stop the stack between
|
|
||||||
// start and delivery.
|
|
||||||
func (t *tcpDNSServer) InjectPacket(payload []byte) {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
|
|
||||||
if t.closed {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !t.running {
|
|
||||||
if err := t.startLocked(); err != nil {
|
|
||||||
log.Errorf("failed to start TCP DNS stack: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.running = true
|
|
||||||
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
|
|
||||||
}
|
|
||||||
t.resetTimerLocked()
|
|
||||||
|
|
||||||
ep := t.ep
|
|
||||||
if ep == nil || ep.dispatcher == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
||||||
Payload: buffer.MakeWithData(payload),
|
|
||||||
})
|
|
||||||
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
|
|
||||||
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop tears down the gvisor stack and releases resources permanently.
|
|
||||||
// After Stop, InjectPacket becomes a no-op.
|
|
||||||
func (t *tcpDNSServer) Stop() {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
|
|
||||||
t.stopLocked()
|
|
||||||
t.closed = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpDNSServer) startLocked() error {
|
|
||||||
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
|
|
||||||
s := stack.New(stack.Options{
|
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
|
|
||||||
HandleLocal: false,
|
|
||||||
})
|
|
||||||
|
|
||||||
nicID := tcpip.NICID(1)
|
|
||||||
ep := &dnsEndpoint{
|
|
||||||
tunDev: t.tunDev,
|
|
||||||
}
|
|
||||||
ep.mtu.Store(uint32(t.mtu))
|
|
||||||
|
|
||||||
if err := s.CreateNIC(nicID, ep); err != nil {
|
|
||||||
s.Close()
|
|
||||||
s.Wait()
|
|
||||||
return fmt.Errorf("create NIC: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
|
||||||
Protocol: ipv4.ProtocolNumber,
|
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
|
||||||
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
|
|
||||||
PrefixLen: 32,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
|
||||||
s.Close()
|
|
||||||
s.Wait()
|
|
||||||
return fmt.Errorf("add protocol address: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
|
||||||
s.Close()
|
|
||||||
s.Wait()
|
|
||||||
return fmt.Errorf("set promiscuous mode: %s", err)
|
|
||||||
}
|
|
||||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
|
||||||
s.Close()
|
|
||||||
s.Wait()
|
|
||||||
return fmt.Errorf("set spoofing: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultSubnet, err := tcpip.NewSubnet(
|
|
||||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
|
||||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
s.Close()
|
|
||||||
s.Wait()
|
|
||||||
return fmt.Errorf("create default subnet: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.SetRouteTable([]tcpip.Route{
|
|
||||||
{Destination: defaultSubnet, NIC: nicID},
|
|
||||||
})
|
|
||||||
|
|
||||||
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
|
|
||||||
t.handleTCPDNS(r)
|
|
||||||
})
|
|
||||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
|
||||||
|
|
||||||
t.s = s
|
|
||||||
t.ep = ep
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpDNSServer) stopLocked() {
|
|
||||||
if !t.running {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.timer != nil {
|
|
||||||
t.timer.Stop()
|
|
||||||
t.timer = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.s != nil {
|
|
||||||
t.s.Close()
|
|
||||||
t.s.Wait()
|
|
||||||
t.s = nil
|
|
||||||
}
|
|
||||||
t.ep = nil
|
|
||||||
t.running = false
|
|
||||||
|
|
||||||
log.Debugf("TCP DNS stack stopped")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpDNSServer) resetTimerLocked() {
|
|
||||||
if t.timer != nil {
|
|
||||||
t.timer.Stop()
|
|
||||||
}
|
|
||||||
t.timerID++
|
|
||||||
id := t.timerID
|
|
||||||
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
|
|
||||||
// Only stop if this timer is still the active one.
|
|
||||||
// A racing InjectPacket may have replaced it.
|
|
||||||
if t.timerID != id {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.stopLocked()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
|
|
||||||
id := r.ID()
|
|
||||||
|
|
||||||
wq := waiter.Queue{}
|
|
||||||
ep, epErr := r.CreateEndpoint(&wq)
|
|
||||||
if epErr != nil {
|
|
||||||
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
|
|
||||||
r.Complete(true)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.Complete(false)
|
|
||||||
|
|
||||||
conn := gonet.NewTCPConn(&wq, ep)
|
|
||||||
defer func() {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Tracef("TCP DNS: close conn: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Reset idle timer on activity
|
|
||||||
t.mu.Lock()
|
|
||||||
t.resetTimerLocked()
|
|
||||||
t.mu.Unlock()
|
|
||||||
|
|
||||||
localAddr := &net.TCPAddr{
|
|
||||||
IP: id.LocalAddress.AsSlice(),
|
|
||||||
Port: int(id.LocalPort),
|
|
||||||
}
|
|
||||||
remoteAddr := &net.TCPAddr{
|
|
||||||
IP: id.RemoteAddress.AsSlice(),
|
|
||||||
Port: int(id.RemotePort),
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
|
|
||||||
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := readTCPDNSMessage(conn)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
|
||||||
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := &tcpResponseWriter{
|
|
||||||
conn: conn,
|
|
||||||
localAddr: localAddr,
|
|
||||||
remoteAddr: remoteAddr,
|
|
||||||
}
|
|
||||||
t.mux.ServeDNS(writer, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
|
|
||||||
type dnsEndpoint struct {
|
|
||||||
dispatcher stack.NetworkDispatcher
|
|
||||||
tunDev tun.Device
|
|
||||||
mtu atomic.Uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
|
|
||||||
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
|
|
||||||
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
|
|
||||||
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
|
|
||||||
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
|
|
||||||
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
|
|
||||||
func (e *dnsEndpoint) Wait() { /* no async work */ }
|
|
||||||
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
|
|
||||||
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
|
|
||||||
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
|
|
||||||
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
|
|
||||||
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
|
|
||||||
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
|
|
||||||
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
|
|
||||||
|
|
||||||
const tunPacketOffset = 40
|
|
||||||
|
|
||||||
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
|
||||||
var written int
|
|
||||||
for _, pkt := range pkts.AsSlice() {
|
|
||||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
|
||||||
if data == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
raw := data.AsSlice()
|
|
||||||
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
|
|
||||||
buf = append(buf, raw...)
|
|
||||||
data.Release()
|
|
||||||
|
|
||||||
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
|
|
||||||
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
written++
|
|
||||||
}
|
|
||||||
return written, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
|
|
||||||
type tcpResponseWriter struct {
|
|
||||||
conn *gonet.TCPConn
|
|
||||||
localAddr net.Addr
|
|
||||||
remoteAddr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *tcpResponseWriter) LocalAddr() net.Addr {
|
|
||||||
return w.localAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
|
|
||||||
return w.remoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
|
|
||||||
data, err := msg.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("pack: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DNS TCP: 2-byte length prefix + message
|
|
||||||
buf := make([]byte, 2+len(data))
|
|
||||||
buf[0] = byte(len(data) >> 8)
|
|
||||||
buf[1] = byte(len(data))
|
|
||||||
copy(buf[2:], data)
|
|
||||||
|
|
||||||
if _, err = w.conn.Write(buf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
|
|
||||||
buf := make([]byte, 2+len(data))
|
|
||||||
buf[0] = byte(len(data) >> 8)
|
|
||||||
buf[1] = byte(len(data))
|
|
||||||
copy(buf[2:], data)
|
|
||||||
if _, err := w.conn.Write(buf); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *tcpResponseWriter) Close() error {
|
|
||||||
return w.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *tcpResponseWriter) TsigStatus() error { return nil }
|
|
||||||
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
|
|
||||||
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
|
|
||||||
|
|
||||||
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
|
|
||||||
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
|
|
||||||
// DNS over TCP uses a 2-byte length prefix
|
|
||||||
lenBuf := make([]byte, 2)
|
|
||||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
|
||||||
return nil, fmt.Errorf("read length: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
|
|
||||||
if msgLen == 0 || msgLen > 65535 {
|
|
||||||
return nil, fmt.Errorf("invalid message length: %d", msgLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgBuf := make([]byte, msgLen)
|
|
||||||
if _, err := io.ReadFull(conn, msgBuf); err != nil {
|
|
||||||
return nil, fmt.Errorf("read message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(msgBuf); err != nil {
|
|
||||||
return nil, fmt.Errorf("unpack: %w", err)
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
|
|
||||||
// Supports both IPv4 and IPv6.
|
|
||||||
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
|
|
||||||
if len(pkt) == 0 {
|
|
||||||
return netip.AddrPort{}
|
|
||||||
}
|
|
||||||
|
|
||||||
srcIP, transportOffset := srcIPFromPacket(pkt)
|
|
||||||
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
|
|
||||||
return netip.AddrPort{}
|
|
||||||
}
|
|
||||||
|
|
||||||
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
|
|
||||||
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
|
|
||||||
switch header.IPVersion(pkt) {
|
|
||||||
case 4:
|
|
||||||
return srcIPv4(pkt)
|
|
||||||
case 6:
|
|
||||||
return srcIPv6(pkt)
|
|
||||||
default:
|
|
||||||
return netip.Addr{}, 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func srcIPv4(pkt []byte) (netip.Addr, int) {
|
|
||||||
if len(pkt) < header.IPv4MinimumSize {
|
|
||||||
return netip.Addr{}, 0
|
|
||||||
}
|
|
||||||
hdr := header.IPv4(pkt)
|
|
||||||
src := hdr.SourceAddress()
|
|
||||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, 0
|
|
||||||
}
|
|
||||||
return ip, int(hdr.HeaderLength())
|
|
||||||
}
|
|
||||||
|
|
||||||
func srcIPv6(pkt []byte) (netip.Addr, int) {
|
|
||||||
if len(pkt) < header.IPv6MinimumSize {
|
|
||||||
return netip.Addr{}, 0
|
|
||||||
}
|
|
||||||
hdr := header.IPv6(pkt)
|
|
||||||
src := hdr.SourceAddress()
|
|
||||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, 0
|
|
||||||
}
|
|
||||||
return ip, header.IPv6MinimumSize
|
|
||||||
}
|
|
||||||
@@ -41,61 +41,10 @@ const (
|
|||||||
|
|
||||||
reactivatePeriod = 30 * time.Second
|
reactivatePeriod = 30 * time.Second
|
||||||
probeTimeout = 2 * time.Second
|
probeTimeout = 2 * time.Second
|
||||||
|
|
||||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
|
||||||
// payload from the tunnel MTU.
|
|
||||||
ipUDPHeaderSize = 60 + 8
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const testRecord = "com."
|
const testRecord = "com."
|
||||||
|
|
||||||
const (
|
|
||||||
protoUDP = "udp"
|
|
||||||
protoTCP = "tcp"
|
|
||||||
)
|
|
||||||
|
|
||||||
type dnsProtocolKey struct{}
|
|
||||||
|
|
||||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
|
||||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
|
||||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
|
|
||||||
func dnsProtocolFromContext(ctx context.Context) string {
|
|
||||||
if ctx == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
type upstreamProtocolKey struct{}
|
|
||||||
|
|
||||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
|
||||||
// Stored as a pointer in context so the exchange function can set it.
|
|
||||||
type upstreamProtocolResult struct {
|
|
||||||
protocol string
|
|
||||||
}
|
|
||||||
|
|
||||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
|
||||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
|
||||||
r := &upstreamProtocolResult{}
|
|
||||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
|
||||||
}
|
|
||||||
|
|
||||||
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
|
|
||||||
func setUpstreamProtocol(ctx context.Context, protocol string) {
|
|
||||||
if ctx == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
|
|
||||||
r.protocol = protocol
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type upstreamClient interface {
|
type upstreamClient interface {
|
||||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||||
}
|
}
|
||||||
@@ -189,16 +138,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Propagate inbound protocol so upstream exchange can use TCP directly
|
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||||
// when the request came in over TCP.
|
|
||||||
ctx := u.ctx
|
|
||||||
if addr := w.RemoteAddr(); addr != nil {
|
|
||||||
network := addr.Network()
|
|
||||||
ctx = contextWithDNSProtocol(ctx, network)
|
|
||||||
resutil.SetMeta(w, "protocol", network)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
|
|
||||||
if len(failures) > 0 {
|
if len(failures) > 0 {
|
||||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||||
}
|
}
|
||||||
@@ -213,7 +153,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||||
timeout := u.upstreamTimeout
|
timeout := u.upstreamTimeout
|
||||||
if len(u.upstreamServers) > 1 {
|
if len(u.upstreamServers) > 1 {
|
||||||
maxTotal := 5 * time.Second
|
maxTotal := 5 * time.Second
|
||||||
@@ -228,7 +168,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.Res
|
|||||||
|
|
||||||
var failures []upstreamFailure
|
var failures []upstreamFailure
|
||||||
for _, upstream := range u.upstreamServers {
|
for _, upstream := range u.upstreamServers {
|
||||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||||
failures = append(failures, *failure)
|
failures = append(failures, *failure)
|
||||||
} else {
|
} else {
|
||||||
return true, failures
|
return true, failures
|
||||||
@@ -238,17 +178,15 @@ func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.Res
|
|||||||
}
|
}
|
||||||
|
|
||||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||||
var rm *dns.Msg
|
var rm *dns.Msg
|
||||||
var t time.Duration
|
var t time.Duration
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
var startTime time.Time
|
var startTime time.Time
|
||||||
var upstreamProto *upstreamProtocolResult
|
|
||||||
func() {
|
func() {
|
||||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
|
||||||
startTime = time.Now()
|
startTime = time.Now()
|
||||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||||
}()
|
}()
|
||||||
@@ -265,7 +203,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
|||||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||||
}
|
}
|
||||||
|
|
||||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -282,13 +220,10 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
|||||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
|
|
||||||
resutil.SetMeta(w, "upstream", upstream.String())
|
resutil.SetMeta(w, "upstream", upstream.String())
|
||||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
|
||||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear Zero bit from external responses to prevent upstream servers from
|
// Clear Zero bit from external responses to prevent upstream servers from
|
||||||
// manipulating our internal fallthrough signaling mechanism
|
// manipulating our internal fallthrough signaling mechanism
|
||||||
@@ -493,42 +428,13 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
|
||||||
func clientUDPMaxSize(r *dns.Msg) int {
|
|
||||||
if opt := r.IsEdns0(); opt != nil {
|
|
||||||
return int(opt.UDPSize())
|
|
||||||
}
|
|
||||||
return dns.MinMsgSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
|
||||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||||
// If the request came in over TCP, go straight to TCP upstream.
|
// MTU - ip + udp headers
|
||||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
|
||||||
tcpClient := *client
|
client.UDPSize = uint16(currentMTU - (60 + 8))
|
||||||
tcpClient.Net = protoTCP
|
|
||||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
|
||||||
if err != nil {
|
|
||||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
|
||||||
}
|
|
||||||
setUpstreamProtocol(ctx, protoTCP)
|
|
||||||
return rm, t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
clientMaxSize := clientUDPMaxSize(r)
|
|
||||||
|
|
||||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
|
||||||
// response larger than our read buffer.
|
|
||||||
// Note: the query could be sent out on an interface that is not ours,
|
|
||||||
// but higher MTU settings could break truncation handling.
|
|
||||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
|
||||||
client.UDPSize = maxUDPPayload
|
|
||||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
|
||||||
opt.SetUDPSize(maxUDPPayload)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
rm *dns.Msg
|
rm *dns.Msg
|
||||||
@@ -547,32 +453,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rm == nil || !rm.MsgHdr.Truncated {
|
if rm == nil || !rm.MsgHdr.Truncated {
|
||||||
setUpstreamProtocol(ctx, protoUDP)
|
|
||||||
return rm, t, nil
|
return rm, t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: if the upstream's truncated UDP response already contains more
|
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||||
// data than the client's buffer, we could truncate locally and skip
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
// the TCP retry.
|
|
||||||
|
|
||||||
tcpClient := *client
|
client.Net = "tcp"
|
||||||
tcpClient.Net = protoTCP
|
|
||||||
|
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
rm, t, err = client.Exchange(r, upstream)
|
||||||
} else {
|
} else {
|
||||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
setUpstreamProtocol(ctx, protoTCP)
|
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||||
|
|
||||||
if rm.Len() > clientMaxSize {
|
|
||||||
rm.Truncate(clientMaxSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rm, t, nil
|
return rm, t, nil
|
||||||
}
|
}
|
||||||
@@ -580,46 +479,18 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||||
// If request came in over TCP, go straight to TCP upstream
|
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
|
||||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
setUpstreamProtocol(ctx, protoTCP)
|
|
||||||
return rm, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
clientMaxSize := clientUDPMaxSize(r)
|
|
||||||
|
|
||||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
|
||||||
// response larger than what we can read over UDP.
|
|
||||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
|
||||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
|
||||||
opt.SetUDPSize(maxUDPPayload)
|
|
||||||
}
|
|
||||||
|
|
||||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If response is truncated, retry with TCP
|
||||||
if reply != nil && reply.MsgHdr.Truncated {
|
if reply != nil && reply.MsgHdr.Truncated {
|
||||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||||
if err != nil {
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
return nil, err
|
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||||
}
|
|
||||||
|
|
||||||
setUpstreamProtocol(ctx, protoTCP)
|
|
||||||
if rm.Len() > clientMaxSize {
|
|
||||||
rm.Truncate(clientMaxSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rm, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
setUpstreamProtocol(ctx, protoUDP)
|
|
||||||
|
|
||||||
return reply, nil
|
return reply, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -640,7 +511,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
|
dnsConn := &dns.Conn{Conn: conn}
|
||||||
|
|
||||||
if err := dnsConn.WriteMsg(r); err != nil {
|
if err := dnsConn.WriteMsg(r); err != nil {
|
||||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
|||||||
upstreamExchangeClient := &dns.Client{
|
upstreamExchangeClient := &dns.Client{
|
||||||
Timeout: ClientTimeout,
|
Timeout: ClientTimeout,
|
||||||
}
|
}
|
||||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||||
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
|||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||||
|
|||||||
@@ -475,298 +475,3 @@ func TestFormatFailures(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSProtocolContext(t *testing.T) {
|
|
||||||
t.Run("roundtrip udp", func(t *testing.T) {
|
|
||||||
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
|
|
||||||
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("roundtrip tcp", func(t *testing.T) {
|
|
||||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
|
||||||
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("missing returns empty", func(t *testing.T) {
|
|
||||||
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExchangeWithFallback_TCPContext(t *testing.T) {
|
|
||||||
// Start a local DNS server that responds on TCP only
|
|
||||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Answer = append(m.Answer, &dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1"),
|
|
||||||
})
|
|
||||||
if err := w.WriteMsg(m); err != nil {
|
|
||||||
t.Logf("write msg: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
tcpServer := &dns.Server{
|
|
||||||
Addr: "127.0.0.1:0",
|
|
||||||
Net: "tcp",
|
|
||||||
Handler: tcpHandler,
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
tcpServer.Listener = tcpLn
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
|
||||||
t.Logf("tcp server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
defer func() {
|
|
||||||
_ = tcpServer.Shutdown()
|
|
||||||
}()
|
|
||||||
|
|
||||||
upstream := tcpLn.Addr().String()
|
|
||||||
|
|
||||||
// With TCP context, should connect directly via TCP without trying UDP
|
|
||||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
|
||||||
client := &dns.Client{Timeout: 2 * time.Second}
|
|
||||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rm)
|
|
||||||
require.NotEmpty(t, rm.Answer)
|
|
||||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
|
|
||||||
// UDP handler returns a truncated response to trigger TCP retry.
|
|
||||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Truncated = true
|
|
||||||
if err := w.WriteMsg(m); err != nil {
|
|
||||||
t.Logf("write msg: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// TCP handler returns the full answer.
|
|
||||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Answer = append(m.Answer, &dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.3"),
|
|
||||||
})
|
|
||||||
if err := w.WriteMsg(m); err != nil {
|
|
||||||
t.Logf("write msg: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
addr := udpPC.LocalAddr().String()
|
|
||||||
|
|
||||||
udpServer := &dns.Server{
|
|
||||||
PacketConn: udpPC,
|
|
||||||
Net: "udp",
|
|
||||||
Handler: udpHandler,
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpLn, err := net.Listen("tcp", addr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tcpServer := &dns.Server{
|
|
||||||
Listener: tcpLn,
|
|
||||||
Net: "tcp",
|
|
||||||
Handler: tcpHandler,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := udpServer.ActivateAndServe(); err != nil {
|
|
||||||
t.Logf("udp server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
|
||||||
t.Logf("tcp server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
defer func() {
|
|
||||||
_ = udpServer.Shutdown()
|
|
||||||
_ = tcpServer.Shutdown()
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &dns.Client{Timeout: 2 * time.Second}
|
|
||||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
|
||||||
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
|
|
||||||
require.NotNil(t, rm)
|
|
||||||
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
|
|
||||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
|
|
||||||
assert.False(t, rm.Truncated, "TCP response should not be truncated")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
|
|
||||||
// Start only a TCP server (no UDP). With TCP context it should succeed.
|
|
||||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Answer = append(m.Answer, &dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.2"),
|
|
||||||
})
|
|
||||||
if err := w.WriteMsg(m); err != nil {
|
|
||||||
t.Logf("write msg: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tcpServer := &dns.Server{
|
|
||||||
Listener: tcpLn,
|
|
||||||
Net: "tcp",
|
|
||||||
Handler: tcpHandler,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
|
||||||
t.Logf("tcp server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
defer func() {
|
|
||||||
_ = tcpServer.Shutdown()
|
|
||||||
}()
|
|
||||||
|
|
||||||
upstream := tcpLn.Addr().String()
|
|
||||||
|
|
||||||
// TCP context: should skip UDP entirely and go directly to TCP
|
|
||||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
|
||||||
client := &dns.Client{Timeout: 2 * time.Second}
|
|
||||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rm)
|
|
||||||
require.NotEmpty(t, rm.Answer)
|
|
||||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
|
|
||||||
|
|
||||||
// Without TCP context, trying to reach a TCP-only server via UDP should fail
|
|
||||||
ctx2 := context.Background()
|
|
||||||
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
|
|
||||||
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
|
|
||||||
assert.Error(t, err, "should fail when no UDP server and no TCP context")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
|
||||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
|
||||||
// capped in the outgoing request so the upstream doesn't send a
|
|
||||||
// response larger than our read buffer.
|
|
||||||
var receivedUDPSize uint16
|
|
||||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
if opt := r.IsEdns0(); opt != nil {
|
|
||||||
receivedUDPSize = opt.UDPSize()
|
|
||||||
}
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Answer = append(m.Answer, &dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1"),
|
|
||||||
})
|
|
||||||
if err := w.WriteMsg(m); err != nil {
|
|
||||||
t.Logf("write msg: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
addr := udpPC.LocalAddr().String()
|
|
||||||
|
|
||||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
|
||||||
go func() { _ = udpServer.ActivateAndServe() }()
|
|
||||||
t.Cleanup(func() { _ = udpServer.Shutdown() })
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &dns.Client{Timeout: 2 * time.Second}
|
|
||||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
|
||||||
r.SetEdns0(4096, false)
|
|
||||||
|
|
||||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rm)
|
|
||||||
|
|
||||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
|
||||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
|
||||||
"upstream should see capped EDNS0, not the client's 4096")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
|
|
||||||
// When the client advertises a large EDNS0 (4096) and the upstream
|
|
||||||
// truncates, the TCP response should NOT be truncated since the full
|
|
||||||
// answer fits within the client's original buffer.
|
|
||||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Truncated = true
|
|
||||||
if err := w.WriteMsg(m); err != nil {
|
|
||||||
t.Logf("write msg: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
// Add enough records to exceed MTU but fit within 4096
|
|
||||||
for i := range 20 {
|
|
||||||
m.Answer = append(m.Answer, &dns.TXT{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if err := w.WriteMsg(m); err != nil {
|
|
||||||
t.Logf("write msg: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
addr := udpPC.LocalAddr().String()
|
|
||||||
|
|
||||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
|
||||||
tcpLn, err := net.Listen("tcp", addr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
|
|
||||||
|
|
||||||
go func() { _ = udpServer.ActivateAndServe() }()
|
|
||||||
go func() { _ = tcpServer.ActivateAndServe() }()
|
|
||||||
t.Cleanup(func() {
|
|
||||||
_ = udpServer.Shutdown()
|
|
||||||
_ = tcpServer.Shutdown()
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &dns.Client{Timeout: 2 * time.Second}
|
|
||||||
|
|
||||||
// Client with large buffer: should get all records without truncation
|
|
||||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
|
||||||
r.SetEdns0(4096, false)
|
|
||||||
|
|
||||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rm)
|
|
||||||
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
|
|
||||||
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
|
|
||||||
|
|
||||||
// Client with small buffer: should get truncated response
|
|
||||||
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
|
||||||
r2.SetEdns0(512, false)
|
|
||||||
|
|
||||||
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, rm2)
|
|
||||||
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
|
|
||||||
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
|
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||||
@@ -263,28 +263,20 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
|||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
fields := log.Fields{
|
logger := log.WithFields(log.Fields{
|
||||||
"request_id": resutil.GenerateRequestID(),
|
"request_id": resutil.GenerateRequestID(),
|
||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
}
|
})
|
||||||
if addr := w.RemoteAddr(); addr != nil {
|
|
||||||
fields["client"] = addr.String()
|
|
||||||
}
|
|
||||||
logger := log.WithFields(fields)
|
|
||||||
|
|
||||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
fields := log.Fields{
|
logger := log.WithFields(log.Fields{
|
||||||
"request_id": resutil.GenerateRequestID(),
|
"request_id": resutil.GenerateRequestID(),
|
||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
}
|
})
|
||||||
if addr := w.RemoteAddr(); addr != nil {
|
|
||||||
fields["client"] = addr.String()
|
|
||||||
}
|
|
||||||
logger := log.WithFields(fields)
|
|
||||||
|
|
||||||
f.handleDNSQuery(logger, w, query, startTime)
|
f.handleDNSQuery(logger, w, query, startTime)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
@@ -140,7 +139,6 @@ type EngineConfig struct {
|
|||||||
ProfileConfig *profilemanager.Config
|
ProfileConfig *profilemanager.Config
|
||||||
|
|
||||||
LogPath string
|
LogPath string
|
||||||
TempDir string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// EngineServices holds the external service dependencies required by the Engine.
|
// EngineServices holds the external service dependencies required by the Engine.
|
||||||
@@ -212,10 +210,9 @@ type Engine struct {
|
|||||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||||
checks []*mgmProto.Checks
|
checks []*mgmProto.Checks
|
||||||
|
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
portForwardManager *portforward.Manager
|
srWatcher *guard.SRWatcher
|
||||||
srWatcher *guard.SRWatcher
|
|
||||||
|
|
||||||
// Sync response persistence (protected by syncRespMux)
|
// Sync response persistence (protected by syncRespMux)
|
||||||
syncRespMux sync.RWMutex
|
syncRespMux sync.RWMutex
|
||||||
@@ -262,27 +259,26 @@ func NewEngine(
|
|||||||
mobileDep MobileDependency,
|
mobileDep MobileDependency,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
signal: services.SignalClient,
|
signal: services.SignalClient,
|
||||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||||
mgmClient: services.MgmClient,
|
mgmClient: services.MgmClient,
|
||||||
relayManager: services.RelayManager,
|
relayManager: services.RelayManager,
|
||||||
peerStore: peerstore.NewConnStore(),
|
peerStore: peerstore.NewConnStore(),
|
||||||
syncMsgMux: &sync.Mutex{},
|
syncMsgMux: &sync.Mutex{},
|
||||||
config: config,
|
config: config,
|
||||||
mobileDep: mobileDep,
|
mobileDep: mobileDep,
|
||||||
STUNs: []*stun.URI{},
|
STUNs: []*stun.URI{},
|
||||||
TURNs: []*stun.URI{},
|
TURNs: []*stun.URI{},
|
||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
statusRecorder: services.StatusRecorder,
|
statusRecorder: services.StatusRecorder,
|
||||||
stateManager: services.StateManager,
|
stateManager: services.StateManager,
|
||||||
portForwardManager: portforward.NewManager(),
|
checks: services.Checks,
|
||||||
checks: services.Checks,
|
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
jobExecutor: jobexec.NewExecutor(),
|
||||||
jobExecutor: jobexec.NewExecutor(),
|
clientMetrics: services.ClientMetrics,
|
||||||
clientMetrics: services.ClientMetrics,
|
updateManager: services.UpdateManager,
|
||||||
updateManager: services.UpdateManager,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||||
@@ -504,7 +500,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
|
|
||||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if r.Network.Contains(ip) {
|
if r.Network.Contains(ip) {
|
||||||
return true
|
return true
|
||||||
@@ -525,11 +521,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Inject firewall into DNS server now that it's available.
|
|
||||||
// The DNS server is created before the firewall because the route manager
|
|
||||||
// depends on the DNS server, and the firewall depends on the wg interface.
|
|
||||||
e.dnsServer.SetFirewall(e.firewall)
|
|
||||||
|
|
||||||
e.udpMux, err = e.wgInterface.Up()
|
e.udpMux, err = e.wgInterface.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||||
@@ -541,13 +532,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
// conntrack entries from being created before the rules are in place
|
// conntrack entries from being created before the rules are in place
|
||||||
e.setupWGProxyNoTrack()
|
e.setupWGProxyNoTrack()
|
||||||
|
|
||||||
// Start after interface is up since port may have been resolved from 0 or changed if occupied
|
|
||||||
e.shutdownWg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer e.shutdownWg.Done()
|
|
||||||
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Set the WireGuard interface for rosenpass after interface is up
|
// Set the WireGuard interface for rosenpass after interface is up
|
||||||
if e.rpManager != nil {
|
if e.rpManager != nil {
|
||||||
e.rpManager.SetInterface(e.wgInterface)
|
e.rpManager.SetInterface(e.wgInterface)
|
||||||
@@ -1096,7 +1080,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
|||||||
StatusRecorder: e.statusRecorder,
|
StatusRecorder: e.statusRecorder,
|
||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: e.config.LogPath,
|
LogPath: e.config.LogPath,
|
||||||
TempDir: e.config.TempDir,
|
|
||||||
ClientMetrics: e.clientMetrics,
|
ClientMetrics: e.clientMetrics,
|
||||||
RefreshStatus: func() {
|
RefreshStatus: func() {
|
||||||
e.RunHealthProbes(true)
|
e.RunHealthProbes(true)
|
||||||
@@ -1552,13 +1535,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
}
|
}
|
||||||
|
|
||||||
serviceDependencies := peer.ServiceDependencies{
|
serviceDependencies := peer.ServiceDependencies{
|
||||||
StatusRecorder: e.statusRecorder,
|
StatusRecorder: e.statusRecorder,
|
||||||
Signaler: e.signaler,
|
Signaler: e.signaler,
|
||||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||||
RelayManager: e.relayManager,
|
RelayManager: e.relayManager,
|
||||||
SrWatcher: e.srWatcher,
|
SrWatcher: e.srWatcher,
|
||||||
PortForwardManager: e.portForwardManager,
|
MetricsRecorder: e.clientMetrics,
|
||||||
MetricsRecorder: e.clientMetrics,
|
|
||||||
}
|
}
|
||||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1715,12 +1697,6 @@ func (e *Engine) close() {
|
|||||||
if e.rpManager != nil {
|
if e.rpManager != nil {
|
||||||
_ = e.rpManager.Close()
|
_ = e.rpManager.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
|
|
||||||
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||||
@@ -1824,7 +1800,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
|||||||
return dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
case "ios":
|
case "ios":
|
||||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||||
return dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -1861,11 +1837,6 @@ func (e *Engine) GetExposeManager() *expose.Manager {
|
|||||||
return e.exposeManager
|
return e.exposeManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsBlockInbound returns whether inbound connections are blocked.
|
|
||||||
func (e *Engine) IsBlockInbound() bool {
|
|
||||||
return e.config.BlockInbound
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientMetrics returns the client metrics
|
// GetClientMetrics returns the client metrics
|
||||||
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||||
return e.clientMetrics
|
return e.clientMetrics
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
@@ -829,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, EngineServices{
|
}, EngineServices{
|
||||||
SignalClient: &signal.MockClient{},
|
SignalClient: &signal.MockClient{},
|
||||||
MgmClient: &mgmt.MockClient{},
|
MgmClient: &mgmt.MockClient{},
|
||||||
RelayManager: relayMgr,
|
RelayManager: relayMgr,
|
||||||
@@ -1036,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, EngineServices{
|
}, EngineServices{
|
||||||
SignalClient: &signal.MockClient{},
|
SignalClient: &signal.MockClient{},
|
||||||
MgmClient: &mgmt.MockClient{},
|
MgmClient: &mgmt.MockClient{},
|
||||||
RelayManager: relayMgr,
|
RelayManager: relayMgr,
|
||||||
@@ -1539,8 +1538,13 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
publicKey, err := mgmtClient.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
info := system.GetInfo(ctx)
|
info := system.GetInfo(ctx)
|
||||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1562,7 +1566,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||||
SignalClient: signalClient,
|
SignalClient: signalClient,
|
||||||
MgmClient: mgmtClient,
|
MgmClient: mgmtClient,
|
||||||
RelayManager: relayMgr,
|
RelayManager: relayMgr,
|
||||||
@@ -1635,12 +1639,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
peersManager := peers.NewManager(store, permissionsManager)
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1662,7 +1661,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,8 +22,4 @@ type MobileDependency struct {
|
|||||||
DnsManager dns.IosDnsManager
|
DnsManager dns.IosDnsManager
|
||||||
FileDescriptor int32
|
FileDescriptor int32
|
||||||
StateFilePath string
|
StateFilePath string
|
||||||
|
|
||||||
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
|
|
||||||
// On Android, this should be set to the app's cache directory.
|
|
||||||
TempDir string
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
nfct "github.com/ti-mo/conntrack"
|
nfct "github.com/ti-mo/conntrack"
|
||||||
@@ -19,64 +17,31 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const defaultChannelSize = 100
|
||||||
defaultChannelSize = 100
|
|
||||||
reconnectInitInterval = 5 * time.Second
|
|
||||||
reconnectMaxInterval = 5 * time.Minute
|
|
||||||
reconnectRandomization = 0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
// listener abstracts a netlink conntrack connection for testability.
|
|
||||||
type listener interface {
|
|
||||||
Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error)
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnTrack manages kernel-based conntrack events
|
// ConnTrack manages kernel-based conntrack events
|
||||||
type ConnTrack struct {
|
type ConnTrack struct {
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
iface nftypes.IFaceMapper
|
iface nftypes.IFaceMapper
|
||||||
|
|
||||||
conn listener
|
conn *nfct.Conn
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
|
|
||||||
dial func() (listener, error)
|
|
||||||
instanceID uuid.UUID
|
instanceID uuid.UUID
|
||||||
started bool
|
started bool
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
sysctlModified bool
|
sysctlModified bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialFunc is a constructor for netlink conntrack connections.
|
|
||||||
type DialFunc func() (listener, error)
|
|
||||||
|
|
||||||
// Option configures a ConnTrack instance.
|
|
||||||
type Option func(*ConnTrack)
|
|
||||||
|
|
||||||
// WithDialer overrides the default netlink dialer, primarily for testing.
|
|
||||||
func WithDialer(dial DialFunc) Option {
|
|
||||||
return func(c *ConnTrack) {
|
|
||||||
c.dial = dial
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultDial() (listener, error) {
|
|
||||||
return nfct.Dial(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
|
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
||||||
ct := &ConnTrack{
|
return &ConnTrack{
|
||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
iface: iface,
|
iface: iface,
|
||||||
instanceID: uuid.New(),
|
instanceID: uuid.New(),
|
||||||
dial: defaultDial,
|
started: false,
|
||||||
done: make(chan struct{}, 1),
|
done: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
|
||||||
opt(ct)
|
|
||||||
}
|
|
||||||
return ct
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
||||||
@@ -94,9 +59,8 @@ func (c *ConnTrack) Start(enableCounters bool) error {
|
|||||||
c.EnableAccounting()
|
c.EnableAccounting()
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := c.dial()
|
conn, err := nfct.Dial(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.RestoreAccounting()
|
|
||||||
return fmt.Errorf("dial conntrack: %w", err)
|
return fmt.Errorf("dial conntrack: %w", err)
|
||||||
}
|
}
|
||||||
c.conn = conn
|
c.conn = conn
|
||||||
@@ -112,16 +76,9 @@ func (c *ConnTrack) Start(enableCounters bool) error {
|
|||||||
log.Errorf("Error closing conntrack connection: %v", err)
|
log.Errorf("Error closing conntrack connection: %v", err)
|
||||||
}
|
}
|
||||||
c.conn = nil
|
c.conn = nil
|
||||||
c.RestoreAccounting()
|
|
||||||
return fmt.Errorf("start conntrack listener: %w", err)
|
return fmt.Errorf("start conntrack listener: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drain any stale stop signal from a previous cycle.
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
c.started = true
|
c.started = true
|
||||||
|
|
||||||
go c.receiverRoutine(events, errChan)
|
go c.receiverRoutine(events, errChan)
|
||||||
@@ -135,98 +92,17 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error)
|
|||||||
case event := <-events:
|
case event := <-events:
|
||||||
c.handleEvent(event)
|
c.handleEvent(event)
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if events, errChan = c.handleListenerError(err); events == nil {
|
log.Errorf("Error from conntrack event listener: %v", err)
|
||||||
return
|
if err := c.conn.Close(); err != nil {
|
||||||
|
log.Errorf("Error closing conntrack connection: %v", err)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleListenerError closes the failed connection and attempts to reconnect.
|
|
||||||
// Returns new channels on success, or nil if shutdown was requested.
|
|
||||||
func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) {
|
|
||||||
log.Warnf("conntrack event listener failed: %v", err)
|
|
||||||
c.closeConn()
|
|
||||||
return c.reconnect()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ConnTrack) closeConn() {
|
|
||||||
c.mux.Lock()
|
|
||||||
defer c.mux.Unlock()
|
|
||||||
|
|
||||||
if c.conn != nil {
|
|
||||||
if err := c.conn.Close(); err != nil {
|
|
||||||
log.Debugf("close conntrack connection: %v", err)
|
|
||||||
}
|
|
||||||
c.conn = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff.
|
|
||||||
// Returns new channels on success, or nil if shutdown was requested.
|
|
||||||
func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) {
|
|
||||||
bo := &backoff.ExponentialBackOff{
|
|
||||||
InitialInterval: reconnectInitInterval,
|
|
||||||
RandomizationFactor: reconnectRandomization,
|
|
||||||
Multiplier: backoff.DefaultMultiplier,
|
|
||||||
MaxInterval: reconnectMaxInterval,
|
|
||||||
MaxElapsedTime: 0, // retry indefinitely
|
|
||||||
Clock: backoff.SystemClock,
|
|
||||||
}
|
|
||||||
bo.Reset()
|
|
||||||
|
|
||||||
for {
|
|
||||||
delay := bo.NextBackOff()
|
|
||||||
log.Infof("reconnecting conntrack listener in %s", delay)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
c.mux.Lock()
|
|
||||||
c.started = false
|
|
||||||
c.mux.Unlock()
|
|
||||||
return nil, nil
|
|
||||||
case <-time.After(delay):
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := c.dial()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("reconnect conntrack dial: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
events := make(chan nfct.Event, defaultChannelSize)
|
|
||||||
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
|
|
||||||
netfilter.GroupCTNew,
|
|
||||||
netfilter.GroupCTDestroy,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("reconnect conntrack listen: %v", err)
|
|
||||||
if closeErr := conn.Close(); closeErr != nil {
|
|
||||||
log.Debugf("close conntrack connection: %v", closeErr)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mux.Lock()
|
|
||||||
if !c.started {
|
|
||||||
// Stop() ran while we were reconnecting.
|
|
||||||
c.mux.Unlock()
|
|
||||||
if closeErr := conn.Close(); closeErr != nil {
|
|
||||||
log.Debugf("close conntrack connection: %v", closeErr)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
c.conn = conn
|
|
||||||
c.mux.Unlock()
|
|
||||||
|
|
||||||
log.Infof("conntrack listener reconnected successfully")
|
|
||||||
|
|
||||||
return events, errChan
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the connection tracking. This method is idempotent.
|
// Stop stops the connection tracking. This method is idempotent.
|
||||||
func (c *ConnTrack) Stop() {
|
func (c *ConnTrack) Stop() {
|
||||||
c.mux.Lock()
|
c.mux.Lock()
|
||||||
@@ -260,27 +136,23 @@ func (c *ConnTrack) Close() error {
|
|||||||
c.mux.Lock()
|
c.mux.Lock()
|
||||||
defer c.mux.Unlock()
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
if !c.started {
|
if c.started {
|
||||||
return nil
|
select {
|
||||||
|
case c.done <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
|
||||||
case c.done <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
c.started = false
|
|
||||||
|
|
||||||
var closeErr error
|
|
||||||
if c.conn != nil {
|
if c.conn != nil {
|
||||||
closeErr = c.conn.Close()
|
err := c.conn.Close()
|
||||||
c.conn = nil
|
c.conn = nil
|
||||||
}
|
c.started = false
|
||||||
|
|
||||||
c.RestoreAccounting()
|
c.RestoreAccounting()
|
||||||
|
|
||||||
if closeErr != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("close conntrack: %w", closeErr)
|
return fmt.Errorf("close conntrack: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,224 +0,0 @@
|
|||||||
//go:build linux && !android
|
|
||||||
|
|
||||||
package conntrack
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
nfct "github.com/ti-mo/conntrack"
|
|
||||||
"github.com/ti-mo/netfilter"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockListener struct {
|
|
||||||
errChan chan error
|
|
||||||
closed atomic.Bool
|
|
||||||
closedCh chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockListener() *mockListener {
|
|
||||||
return &mockListener{
|
|
||||||
errChan: make(chan error, 1),
|
|
||||||
closedCh: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) {
|
|
||||||
return m.errChan, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockListener) Close() error {
|
|
||||||
if m.closed.CompareAndSwap(false, true) {
|
|
||||||
close(m.closedCh)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReconnectAfterError(t *testing.T) {
|
|
||||||
first := newMockListener()
|
|
||||||
second := newMockListener()
|
|
||||||
third := newMockListener()
|
|
||||||
listeners := []*mockListener{first, second, third}
|
|
||||||
callCount := atomic.Int32{}
|
|
||||||
|
|
||||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
|
||||||
n := int(callCount.Add(1)) - 1
|
|
||||||
return listeners[n], nil
|
|
||||||
}))
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Inject an error on the first listener.
|
|
||||||
first.errChan <- assert.AnError
|
|
||||||
|
|
||||||
// Wait for reconnect to complete.
|
|
||||||
require.Eventually(t, func() bool {
|
|
||||||
return callCount.Load() >= 2
|
|
||||||
}, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection")
|
|
||||||
|
|
||||||
// The first connection must have been closed.
|
|
||||||
select {
|
|
||||||
case <-first.closedCh:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("first connection was not closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the receiver is still running by injecting and handling a second error.
|
|
||||||
second.errChan <- assert.AnError
|
|
||||||
|
|
||||||
require.Eventually(t, func() bool {
|
|
||||||
return callCount.Load() >= 3
|
|
||||||
}, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed")
|
|
||||||
|
|
||||||
ct.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStopDuringReconnectBackoff(t *testing.T) {
|
|
||||||
mock := newMockListener()
|
|
||||||
|
|
||||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
|
||||||
return mock, nil
|
|
||||||
}))
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Trigger an error so the receiver enters reconnect.
|
|
||||||
mock.errChan <- assert.AnError
|
|
||||||
|
|
||||||
// Wait for the error handler to close the old listener before calling Stop.
|
|
||||||
select {
|
|
||||||
case <-mock.closedCh:
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
t.Fatal("timed out waiting for reconnect to start")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop while reconnecting.
|
|
||||||
ct.Stop()
|
|
||||||
|
|
||||||
ct.mux.Lock()
|
|
||||||
assert.False(t, ct.started, "started should be false after Stop")
|
|
||||||
assert.Nil(t, ct.conn, "conn should be nil after Stop")
|
|
||||||
ct.mux.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStopRaceWithReconnectDial(t *testing.T) {
|
|
||||||
first := newMockListener()
|
|
||||||
dialStarted := make(chan struct{})
|
|
||||||
dialProceed := make(chan struct{})
|
|
||||||
second := newMockListener()
|
|
||||||
callCount := atomic.Int32{}
|
|
||||||
|
|
||||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
|
||||||
n := callCount.Add(1)
|
|
||||||
if n == 1 {
|
|
||||||
return first, nil
|
|
||||||
}
|
|
||||||
// Second dial: signal that we're in progress, wait for test to call Stop.
|
|
||||||
close(dialStarted)
|
|
||||||
<-dialProceed
|
|
||||||
return second, nil
|
|
||||||
}))
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Trigger error to enter reconnect.
|
|
||||||
first.errChan <- assert.AnError
|
|
||||||
|
|
||||||
// Wait for reconnect's second dial to begin.
|
|
||||||
select {
|
|
||||||
case <-dialStarted:
|
|
||||||
case <-time.After(15 * time.Second):
|
|
||||||
t.Fatal("timed out waiting for reconnect dial")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop while dial is in progress (conn is nil at this point).
|
|
||||||
ct.Stop()
|
|
||||||
|
|
||||||
// Let the dial complete. reconnect should detect started==false and close the new conn.
|
|
||||||
close(dialProceed)
|
|
||||||
|
|
||||||
// The second connection should be closed (not leaked).
|
|
||||||
select {
|
|
||||||
case <-second.closedCh:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("second connection was leaked after Stop")
|
|
||||||
}
|
|
||||||
|
|
||||||
ct.mux.Lock()
|
|
||||||
assert.False(t, ct.started)
|
|
||||||
assert.Nil(t, ct.conn)
|
|
||||||
ct.mux.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloseRaceWithReconnectDial(t *testing.T) {
|
|
||||||
first := newMockListener()
|
|
||||||
dialStarted := make(chan struct{})
|
|
||||||
dialProceed := make(chan struct{})
|
|
||||||
second := newMockListener()
|
|
||||||
callCount := atomic.Int32{}
|
|
||||||
|
|
||||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
|
||||||
n := callCount.Add(1)
|
|
||||||
if n == 1 {
|
|
||||||
return first, nil
|
|
||||||
}
|
|
||||||
close(dialStarted)
|
|
||||||
<-dialProceed
|
|
||||||
return second, nil
|
|
||||||
}))
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
first.errChan <- assert.AnError
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-dialStarted:
|
|
||||||
case <-time.After(15 * time.Second):
|
|
||||||
t.Fatal("timed out waiting for reconnect dial")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close while dial is in progress (conn is nil).
|
|
||||||
require.NoError(t, ct.Close())
|
|
||||||
|
|
||||||
close(dialProceed)
|
|
||||||
|
|
||||||
// The second connection should be closed (not leaked).
|
|
||||||
select {
|
|
||||||
case <-second.closedCh:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("second connection was leaked after Close")
|
|
||||||
}
|
|
||||||
|
|
||||||
ct.mux.Lock()
|
|
||||||
assert.False(t, ct.started)
|
|
||||||
assert.Nil(t, ct.conn)
|
|
||||||
ct.mux.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStartIsIdempotent(t *testing.T) {
|
|
||||||
mock := newMockListener()
|
|
||||||
callCount := atomic.Int32{}
|
|
||||||
|
|
||||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
|
||||||
callCount.Add(1)
|
|
||||||
return mock, nil
|
|
||||||
}))
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Second Start should be a no-op.
|
|
||||||
err = ct.Start(false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once")
|
|
||||||
|
|
||||||
ct.Stop()
|
|
||||||
}
|
|
||||||
@@ -22,7 +22,6 @@ import (
|
|||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
@@ -46,7 +45,6 @@ type ServiceDependencies struct {
|
|||||||
RelayManager *relayClient.Manager
|
RelayManager *relayClient.Manager
|
||||||
SrWatcher *guard.SRWatcher
|
SrWatcher *guard.SRWatcher
|
||||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||||
PortForwardManager *portforward.Manager
|
|
||||||
MetricsRecorder MetricsRecorder
|
MetricsRecorder MetricsRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,17 +87,16 @@ type ConnConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
Log *log.Entry
|
Log *log.Entry
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
srWatcher *guard.SRWatcher
|
srWatcher *guard.SRWatcher
|
||||||
portForwardManager *portforward.Manager
|
|
||||||
|
|
||||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||||
onDisconnected func(remotePeer string)
|
onDisconnected func(remotePeer string)
|
||||||
@@ -148,20 +145,19 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
|
|
||||||
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
|
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
|
||||||
var conn = &Conn{
|
var conn = &Conn{
|
||||||
Log: connLog,
|
Log: connLog,
|
||||||
config: config,
|
config: config,
|
||||||
statusRecorder: services.StatusRecorder,
|
statusRecorder: services.StatusRecorder,
|
||||||
signaler: services.Signaler,
|
signaler: services.Signaler,
|
||||||
iFaceDiscover: services.IFaceDiscover,
|
iFaceDiscover: services.IFaceDiscover,
|
||||||
relayManager: services.RelayManager,
|
relayManager: services.RelayManager,
|
||||||
srWatcher: services.SrWatcher,
|
srWatcher: services.SrWatcher,
|
||||||
portForwardManager: services.PortForwardManager,
|
statusRelay: worker.NewAtomicStatus(),
|
||||||
statusRelay: worker.NewAtomicStatus(),
|
statusICE: worker.NewAtomicStatus(),
|
||||||
statusICE: worker.NewAtomicStatus(),
|
dumpState: dumpState,
|
||||||
dumpState: dumpState,
|
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
metricsRecorder: services.MetricsRecorder,
|
||||||
metricsRecorder: services.MetricsRecorder,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -62,9 +61,6 @@ type WorkerICE struct {
|
|||||||
|
|
||||||
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||||
lastKnownState ice.ConnectionState
|
lastKnownState ice.ConnectionState
|
||||||
|
|
||||||
// portForwardAttempted tracks if we've already tried port forwarding this session
|
|
||||||
portForwardAttempted bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
|
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
|
||||||
@@ -218,8 +214,6 @@ func (w *WorkerICE) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||||
w.portForwardAttempted = false
|
|
||||||
|
|
||||||
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create agent: %w", err)
|
return nil, fmt.Errorf("create agent: %w", err)
|
||||||
@@ -376,93 +370,6 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
|||||||
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if candidate.Type() == ice.CandidateTypeServerReflexive {
|
|
||||||
w.injectPortForwardedCandidate(candidate)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
|
|
||||||
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
|
|
||||||
pfManager := w.conn.portForwardManager
|
|
||||||
if pfManager == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
mapping := pfManager.GetMapping()
|
|
||||||
if mapping == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.muxAgent.Lock()
|
|
||||||
if w.portForwardAttempted {
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.portForwardAttempted = true
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
|
|
||||||
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
|
|
||||||
if err != nil {
|
|
||||||
w.log.Warnf("create forwarded candidate: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
|
|
||||||
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
|
|
||||||
w.log.Errorf("signal port-forwarded candidate: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
|
|
||||||
// It uses the NAT gateway's external IP with the forwarded port.
|
|
||||||
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
|
|
||||||
var externalIP string
|
|
||||||
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
|
|
||||||
externalIP = mapping.ExternalIP.String()
|
|
||||||
} else {
|
|
||||||
// Fallback to STUN-discovered address if NAT didn't provide external IP
|
|
||||||
externalIP = srflxCandidate.Address()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Per RFC 8445, the related address for srflx is the base (host candidate address).
|
|
||||||
// If the original srflx has unspecified related address, use its own address as base.
|
|
||||||
relAddr := srflxCandidate.RelatedAddress().Address
|
|
||||||
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
|
|
||||||
relAddr = srflxCandidate.Address()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
|
|
||||||
// over regular srflx during ICE connectivity checks.
|
|
||||||
priority := srflxCandidate.Priority() + 1000
|
|
||||||
|
|
||||||
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
|
||||||
Network: srflxCandidate.NetworkType().String(),
|
|
||||||
Address: externalIP,
|
|
||||||
Port: int(mapping.ExternalPort),
|
|
||||||
Component: srflxCandidate.Component(),
|
|
||||||
Priority: priority,
|
|
||||||
RelAddr: relAddr,
|
|
||||||
RelPort: int(mapping.InternalPort),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create candidate: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, e := range srflxCandidate.Extensions() {
|
|
||||||
if e.Key == ice.ExtensionKeyCandidateID {
|
|
||||||
e.Value = srflxCandidate.ID()
|
|
||||||
}
|
|
||||||
if err := candidate.AddExtension(e); err != nil {
|
|
||||||
return nil, fmt.Errorf("add extension: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return candidate, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||||
@@ -504,10 +411,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
|
|||||||
if !lok || !rok {
|
if !lok || !rok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
|
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
|
||||||
sessionID,
|
sessionID,
|
||||||
local.NetworkType(), local.Type(), local.Address(), local.Port(),
|
local.NetworkType(), local.Type(), local.Address(),
|
||||||
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
|
remote.NetworkType(), remote.Type(), remote.Address(),
|
||||||
stat.CurrentRoundTripTime*1000)
|
stat.CurrentRoundTripTime*1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +0,0 @@
|
|||||||
package portforward
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
|
|
||||||
envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK"
|
|
||||||
)
|
|
||||||
|
|
||||||
func isDisabledByEnv() bool {
|
|
||||||
return parseBoolEnv(envDisableNATMapper)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isHealthCheckDisabled() bool {
|
|
||||||
return parseBoolEnv(envDisablePCPHealthCheck)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseBoolEnv(key string) bool {
|
|
||||||
val := os.Getenv(key)
|
|
||||||
if val == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
disabled, err := strconv.ParseBool(val)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse %s: %v", key, err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return disabled
|
|
||||||
}
|
|
||||||
@@ -1,342 +0,0 @@
|
|||||||
//go:build !js
|
|
||||||
|
|
||||||
package portforward
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"regexp"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/libp2p/go-nat"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultMappingTTL = 2 * time.Hour
|
|
||||||
healthCheckInterval = 1 * time.Minute
|
|
||||||
discoveryTimeout = 10 * time.Second
|
|
||||||
mappingDescription = "NetBird"
|
|
||||||
)
|
|
||||||
|
|
||||||
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
|
|
||||||
// allowing for whitespace/newlines between tags from different router firmware.
|
|
||||||
var upnpErrPermanentLeaseOnly = regexp.MustCompile(`<errorCode>\s*725\s*</errorCode>`)
|
|
||||||
|
|
||||||
// Mapping represents an active NAT port mapping.
|
|
||||||
type Mapping struct {
|
|
||||||
Protocol string
|
|
||||||
InternalPort uint16
|
|
||||||
ExternalPort uint16
|
|
||||||
ExternalIP net.IP
|
|
||||||
NATType string
|
|
||||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
|
||||||
TTL time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: persist mapping state for crash recovery cleanup of permanent leases.
|
|
||||||
// Currently not done because State.Cleanup requires NAT gateway re-discovery,
|
|
||||||
// which blocks startup for ~10s when no gateway is present (affects all clients).
|
|
||||||
|
|
||||||
type Manager struct {
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
mapping *Mapping
|
|
||||||
mappingLock sync.Mutex
|
|
||||||
|
|
||||||
wgPort uint16
|
|
||||||
|
|
||||||
done chan struct{}
|
|
||||||
stopCtx chan context.Context
|
|
||||||
|
|
||||||
// protect exported functions
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManager creates a new port forwarding manager.
|
|
||||||
func NewManager() *Manager {
|
|
||||||
return &Manager{
|
|
||||||
stopCtx: make(chan context.Context, 1),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
|
|
||||||
m.mu.Lock()
|
|
||||||
if m.cancel != nil {
|
|
||||||
m.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if isDisabledByEnv() {
|
|
||||||
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
|
|
||||||
m.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if wgPort == 0 {
|
|
||||||
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
|
|
||||||
m.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.wgPort = wgPort
|
|
||||||
|
|
||||||
m.done = make(chan struct{})
|
|
||||||
defer close(m.done)
|
|
||||||
|
|
||||||
ctx, m.cancel = context.WithCancel(ctx)
|
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
gateway, mapping, err := m.setup(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.Infof("port forwarding setup: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mappingLock.Lock()
|
|
||||||
m.mapping = mapping
|
|
||||||
m.mappingLock.Unlock()
|
|
||||||
|
|
||||||
m.renewLoop(ctx, gateway, mapping.TTL)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case cleanupCtx := <-m.stopCtx:
|
|
||||||
// block the Start while cleaned up gracefully
|
|
||||||
m.cleanup(cleanupCtx, gateway)
|
|
||||||
default:
|
|
||||||
// return Start immediately and cleanup in background
|
|
||||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
go func() {
|
|
||||||
defer cleanupCancel()
|
|
||||||
m.cleanup(cleanupCtx, gateway)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMapping returns the current mapping if ready, nil otherwise
|
|
||||||
func (m *Manager) GetMapping() *Mapping {
|
|
||||||
m.mappingLock.Lock()
|
|
||||||
defer m.mappingLock.Unlock()
|
|
||||||
|
|
||||||
if m.mapping == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mapping := *m.mapping
|
|
||||||
return &mapping
|
|
||||||
}
|
|
||||||
|
|
||||||
// GracefullyStop cancels the manager and attempts to delete the port mapping.
|
|
||||||
// After GracefullyStop returns, the manager cannot be restarted.
|
|
||||||
func (m *Manager) GracefullyStop(ctx context.Context) error {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if m.cancel == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
|
|
||||||
m.startTearDown(ctx)
|
|
||||||
|
|
||||||
m.cancel()
|
|
||||||
m.cancel = nil
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-m.done:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
|
|
||||||
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
|
|
||||||
defer discoverCancel()
|
|
||||||
|
|
||||||
gateway, err := discoverGateway(discoverCtx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("discover gateway: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("discovered NAT gateway: %s", gateway.Type())
|
|
||||||
|
|
||||||
mapping, err := m.createMapping(ctx, gateway)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("create port mapping: %w", err)
|
|
||||||
}
|
|
||||||
return gateway, mapping, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
ttl := defaultMappingTTL
|
|
||||||
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
|
||||||
if err != nil {
|
|
||||||
if !isPermanentLeaseRequired(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
log.Infof("gateway only supports permanent leases, retrying with indefinite duration")
|
|
||||||
ttl = 0
|
|
||||||
externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
externalIP, err := gateway.GetExternalAddress()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get external address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mapping := &Mapping{
|
|
||||||
Protocol: "udp",
|
|
||||||
InternalPort: m.wgPort,
|
|
||||||
ExternalPort: uint16(externalPort),
|
|
||||||
ExternalIP: externalIP,
|
|
||||||
NATType: gateway.Type(),
|
|
||||||
TTL: ttl,
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
|
|
||||||
m.wgPort, externalPort, gateway.Type(), externalIP)
|
|
||||||
return mapping, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
|
|
||||||
if ttl == 0 {
|
|
||||||
// Permanent mappings don't expire, just wait for cancellation
|
|
||||||
// but still run health checks for PCP gateways.
|
|
||||||
m.permanentLeaseLoop(ctx, gateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
renewTicker := time.NewTicker(ttl / 2)
|
|
||||||
healthTicker := time.NewTicker(healthCheckInterval)
|
|
||||||
defer renewTicker.Stop()
|
|
||||||
defer healthTicker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-renewTicker.C:
|
|
||||||
if err := m.renewMapping(ctx, gateway); err != nil {
|
|
||||||
log.Warnf("failed to renew port mapping: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
case <-healthTicker.C:
|
|
||||||
if m.checkHealthAndRecreate(ctx, gateway) {
|
|
||||||
renewTicker.Reset(ttl / 2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) {
|
|
||||||
healthTicker := time.NewTicker(healthCheckInterval)
|
|
||||||
defer healthTicker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-healthTicker.C:
|
|
||||||
m.checkHealthAndRecreate(ctx, gateway)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool {
|
|
||||||
if isHealthCheckDisabled() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mappingLock.Lock()
|
|
||||||
hasMapping := m.mapping != nil
|
|
||||||
m.mappingLock.Unlock()
|
|
||||||
|
|
||||||
if !hasMapping {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
pcpNAT, ok := gateway.(*pcp.NAT)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("PCP health check failed: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if serverRestarted {
|
|
||||||
log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch)
|
|
||||||
if err := m.renewMapping(ctx, gateway); err != nil {
|
|
||||||
log.Errorf("failed to recreate port mapping after server restart: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("add port mapping: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if uint16(externalPort) != m.mapping.ExternalPort {
|
|
||||||
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
|
|
||||||
m.mappingLock.Lock()
|
|
||||||
m.mapping.ExternalPort = uint16(externalPort)
|
|
||||||
m.mappingLock.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
|
|
||||||
m.mappingLock.Lock()
|
|
||||||
mapping := m.mapping
|
|
||||||
m.mapping = nil
|
|
||||||
m.mappingLock.Unlock()
|
|
||||||
|
|
||||||
if mapping == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
|
|
||||||
log.Warnf("delete port mapping on stop: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) startTearDown(ctx context.Context) {
|
|
||||||
select {
|
|
||||||
case m.stopCtx <- ctx:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725).
|
|
||||||
func isPermanentLeaseRequired(err error) bool {
|
|
||||||
return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error())
|
|
||||||
}
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
package portforward
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Mapping represents an active NAT port mapping.
|
|
||||||
type Mapping struct {
|
|
||||||
Protocol string
|
|
||||||
InternalPort uint16
|
|
||||||
ExternalPort uint16
|
|
||||||
ExternalIP net.IP
|
|
||||||
NATType string
|
|
||||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
|
||||||
TTL time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
|
|
||||||
type Manager struct{}
|
|
||||||
|
|
||||||
// NewManager returns a stub manager for js/wasm builds.
|
|
||||||
func NewManager() *Manager {
|
|
||||||
return &Manager{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
|
|
||||||
func (m *Manager) Start(context.Context, uint16) {
|
|
||||||
// no NAT traversal in wasm
|
|
||||||
}
|
|
||||||
|
|
||||||
// GracefullyStop is a no-op on js/wasm.
|
|
||||||
func (m *Manager) GracefullyStop(context.Context) error { return nil }
|
|
||||||
|
|
||||||
// GetMapping always returns nil on js/wasm.
|
|
||||||
func (m *Manager) GetMapping() *Mapping {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
//go:build !js
|
|
||||||
|
|
||||||
package portforward
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockNAT struct {
|
|
||||||
natType string
|
|
||||||
deviceAddr net.IP
|
|
||||||
externalAddr net.IP
|
|
||||||
internalAddr net.IP
|
|
||||||
mappings map[int]int
|
|
||||||
addMappingErr error
|
|
||||||
deleteMappingErr error
|
|
||||||
onlyPermanentLeases bool
|
|
||||||
lastTimeout time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockNAT() *mockNAT {
|
|
||||||
return &mockNAT{
|
|
||||||
natType: "Mock-NAT",
|
|
||||||
deviceAddr: net.ParseIP("192.168.1.1"),
|
|
||||||
externalAddr: net.ParseIP("203.0.113.50"),
|
|
||||||
internalAddr: net.ParseIP("192.168.1.100"),
|
|
||||||
mappings: make(map[int]int),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNAT) Type() string {
|
|
||||||
return m.natType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
|
|
||||||
return m.deviceAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
|
|
||||||
return m.externalAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
|
|
||||||
return m.internalAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
|
|
||||||
if m.addMappingErr != nil {
|
|
||||||
return 0, m.addMappingErr
|
|
||||||
}
|
|
||||||
if m.onlyPermanentLeases && timeout != 0 {
|
|
||||||
return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: <UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\"><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>")
|
|
||||||
}
|
|
||||||
externalPort := internalPort
|
|
||||||
m.mappings[internalPort] = externalPort
|
|
||||||
m.lastTimeout = timeout
|
|
||||||
return externalPort, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
|
||||||
if m.deleteMappingErr != nil {
|
|
||||||
return m.deleteMappingErr
|
|
||||||
}
|
|
||||||
delete(m.mappings, internalPort)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_CreateMapping(t *testing.T) {
|
|
||||||
m := NewManager()
|
|
||||||
m.wgPort = 51820
|
|
||||||
|
|
||||||
gateway := newMockNAT()
|
|
||||||
mapping, err := m.createMapping(context.Background(), gateway)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, mapping)
|
|
||||||
|
|
||||||
assert.Equal(t, "udp", mapping.Protocol)
|
|
||||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
|
||||||
assert.Equal(t, uint16(51820), mapping.ExternalPort)
|
|
||||||
assert.Equal(t, "Mock-NAT", mapping.NATType)
|
|
||||||
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
|
|
||||||
assert.Equal(t, defaultMappingTTL, mapping.TTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
|
|
||||||
m := NewManager()
|
|
||||||
assert.Nil(t, m.GetMapping())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
|
|
||||||
m := NewManager()
|
|
||||||
m.mapping = &Mapping{
|
|
||||||
Protocol: "udp",
|
|
||||||
InternalPort: 51820,
|
|
||||||
ExternalPort: 51820,
|
|
||||||
}
|
|
||||||
|
|
||||||
mapping := m.GetMapping()
|
|
||||||
require.NotNil(t, mapping)
|
|
||||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
|
||||||
|
|
||||||
// Mutating the returned copy should not affect the manager's mapping.
|
|
||||||
mapping.ExternalPort = 9999
|
|
||||||
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
|
|
||||||
m := NewManager()
|
|
||||||
m.mapping = &Mapping{
|
|
||||||
Protocol: "udp",
|
|
||||||
InternalPort: 51820,
|
|
||||||
ExternalPort: 51820,
|
|
||||||
}
|
|
||||||
|
|
||||||
gateway := newMockNAT()
|
|
||||||
// Seed the mock so we can verify deletion.
|
|
||||||
gateway.mappings[51820] = 51820
|
|
||||||
|
|
||||||
m.cleanup(context.Background(), gateway)
|
|
||||||
|
|
||||||
_, exists := gateway.mappings[51820]
|
|
||||||
assert.False(t, exists, "mapping should be deleted from gateway")
|
|
||||||
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_Cleanup_NilMapping(t *testing.T) {
|
|
||||||
m := NewManager()
|
|
||||||
gateway := newMockNAT()
|
|
||||||
|
|
||||||
// Should not panic or call gateway.
|
|
||||||
m.cleanup(context.Background(), gateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) {
|
|
||||||
m := NewManager()
|
|
||||||
m.wgPort = 51820
|
|
||||||
|
|
||||||
gateway := newMockNAT()
|
|
||||||
gateway.onlyPermanentLeases = true
|
|
||||||
|
|
||||||
mapping, err := m.createMapping(context.Background(), gateway)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, mapping)
|
|
||||||
|
|
||||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
|
||||||
assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease")
|
|
||||||
assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsPermanentLeaseRequired(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "nil error",
|
|
||||||
err: nil,
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "UPnP error 725",
|
|
||||||
err: fmt.Errorf("SOAP fault. Code: | Detail: <UPnPError><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>"),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "wrapped error with 725",
|
|
||||||
err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: <errorCode>725</errorCode>")),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error 725 with newlines in XML",
|
|
||||||
err: fmt.Errorf("<errorCode>\n 725\n</errorCode>"),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bare 725 without XML tag",
|
|
||||||
err: fmt.Errorf("error code 725"),
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "unrelated error",
|
|
||||||
err: fmt.Errorf("connection refused"),
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,408 +0,0 @@
|
|||||||
package pcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultTimeout = 3 * time.Second
|
|
||||||
responseBufferSize = 128
|
|
||||||
|
|
||||||
// RFC 6887 Section 8.1.1 retry timing
|
|
||||||
initialRetryDelay = 3 * time.Second
|
|
||||||
maxRetryDelay = 1024 * time.Second
|
|
||||||
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
|
|
||||||
)
|
|
||||||
|
|
||||||
// Client is a PCP protocol client.
|
|
||||||
// All methods are safe for concurrent use.
|
|
||||||
type Client struct {
|
|
||||||
gateway netip.Addr
|
|
||||||
timeout time.Duration
|
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
// localIP caches the resolved local IP address.
|
|
||||||
localIP netip.Addr
|
|
||||||
// lastEpoch is the last observed server epoch value.
|
|
||||||
lastEpoch uint32
|
|
||||||
// epochTime tracks when lastEpoch was received for state loss detection.
|
|
||||||
epochTime time.Time
|
|
||||||
// externalIP caches the external IP from the last successful MAP response.
|
|
||||||
externalIP netip.Addr
|
|
||||||
// epochStateLost is set when epoch indicates server restart.
|
|
||||||
epochStateLost bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a new PCP client for the gateway at the given IP.
|
|
||||||
func NewClient(gateway net.IP) *Client {
|
|
||||||
addr, ok := netip.AddrFromSlice(gateway)
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("invalid gateway IP: %v", gateway)
|
|
||||||
}
|
|
||||||
return &Client{
|
|
||||||
gateway: addr.Unmap(),
|
|
||||||
timeout: defaultTimeout,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClientWithTimeout creates a new PCP client with a custom timeout.
|
|
||||||
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
|
|
||||||
addr, ok := netip.AddrFromSlice(gateway)
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("invalid gateway IP: %v", gateway)
|
|
||||||
}
|
|
||||||
return &Client{
|
|
||||||
gateway: addr.Unmap(),
|
|
||||||
timeout: timeout,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLocalIP sets the local IP address to use in PCP requests.
|
|
||||||
func (c *Client) SetLocalIP(ip net.IP) {
|
|
||||||
addr, ok := netip.AddrFromSlice(ip)
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("invalid local IP: %v", ip)
|
|
||||||
}
|
|
||||||
c.mu.Lock()
|
|
||||||
c.localIP = addr.Unmap()
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gateway returns the gateway IP address.
|
|
||||||
func (c *Client) Gateway() net.IP {
|
|
||||||
return c.gateway.AsSlice()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Announce sends a PCP ANNOUNCE request to discover PCP support.
|
|
||||||
// Returns the server's epoch time on success.
|
|
||||||
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
|
|
||||||
localIP, err := c.getLocalIP()
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("get local IP: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req := buildAnnounceRequest(localIP)
|
|
||||||
resp, err := c.sendRequest(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("send announce: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
parsed, err := parseResponse(resp)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("parse announce response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if parsed.ResultCode != ResultSuccess {
|
|
||||||
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mu.Lock()
|
|
||||||
if c.updateEpochLocked(parsed.Epoch) {
|
|
||||||
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
return parsed.Epoch, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPortMapping requests a port mapping from the PCP server.
|
|
||||||
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
|
|
||||||
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
|
|
||||||
// Use lifetime <= 0 to delete a mapping.
|
|
||||||
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
|
|
||||||
var extIP netip.Addr
|
|
||||||
if suggestedExtIP != nil {
|
|
||||||
var ok bool
|
|
||||||
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
|
|
||||||
}
|
|
||||||
extIP = extIP.Unmap()
|
|
||||||
}
|
|
||||||
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
|
|
||||||
localIP, err := c.getLocalIP()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get local IP: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
proto, err := protocolNumber(protocol)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse protocol: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var nonce [12]byte
|
|
||||||
if _, err := rand.Read(nonce[:]); err != nil {
|
|
||||||
return nil, fmt.Errorf("generate nonce: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
|
|
||||||
// default for positive durations that round to 0 seconds.
|
|
||||||
var lifetimeSec uint32
|
|
||||||
if lifetime > 0 {
|
|
||||||
lifetimeSec = uint32(lifetime.Seconds())
|
|
||||||
if lifetimeSec == 0 {
|
|
||||||
lifetimeSec = DefaultLifetime
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
|
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("send map request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mapResp, err := parseMapResponse(resp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse map response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if mapResp.Nonce != nonce {
|
|
||||||
return nil, fmt.Errorf("nonce mismatch in response")
|
|
||||||
}
|
|
||||||
|
|
||||||
if mapResp.Protocol != proto {
|
|
||||||
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
|
|
||||||
}
|
|
||||||
if mapResp.InternalPort != uint16(internalPort) {
|
|
||||||
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
if mapResp.ResultCode != ResultSuccess {
|
|
||||||
return nil, &Error{
|
|
||||||
Code: mapResp.ResultCode,
|
|
||||||
Message: ResultCodeString(mapResp.ResultCode),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mu.Lock()
|
|
||||||
if c.updateEpochLocked(mapResp.Epoch) {
|
|
||||||
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
|
|
||||||
}
|
|
||||||
c.cacheExternalIPLocked(mapResp.ExternalIP)
|
|
||||||
c.mu.Unlock()
|
|
||||||
return mapResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePortMapping removes a port mapping by requesting zero lifetime.
|
|
||||||
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
|
||||||
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
|
|
||||||
var pcpErr *Error
|
|
||||||
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("delete mapping: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetExternalAddress returns the external IP address.
|
|
||||||
// First checks for a cached value from previous MAP responses.
|
|
||||||
// If not cached, creates a short-lived mapping to discover the external IP.
|
|
||||||
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
|
|
||||||
c.mu.Lock()
|
|
||||||
if c.externalIP.IsValid() {
|
|
||||||
ip := c.externalIP.AsSlice()
|
|
||||||
c.mu.Unlock()
|
|
||||||
return ip, nil
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
// Use an ephemeral port in the dynamic range (49152-65535).
|
|
||||||
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
|
|
||||||
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
|
|
||||||
|
|
||||||
// Use minimal lifetime (1 second) for discovery.
|
|
||||||
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create temporary mapping: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
|
|
||||||
log.Debugf("cleanup temporary PCP mapping: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp.ExternalIP.AsSlice(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LastEpoch returns the last observed server epoch value.
|
|
||||||
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
|
|
||||||
func (c *Client) LastEpoch() uint32 {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
return c.lastEpoch
|
|
||||||
}
|
|
||||||
|
|
||||||
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
|
|
||||||
func (c *Client) EpochStateLost() bool {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
lost := c.epochStateLost
|
|
||||||
c.epochStateLost = false
|
|
||||||
return lost
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateEpoch updates the epoch tracking and detects potential state loss.
|
|
||||||
// Returns true if state loss was detected (server likely restarted).
|
|
||||||
// Caller must hold c.mu.
|
|
||||||
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
|
|
||||||
now := time.Now()
|
|
||||||
stateLost := false
|
|
||||||
|
|
||||||
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
|
|
||||||
// client_delta = time since last response
|
|
||||||
// server_delta = epoch change since last response
|
|
||||||
// Invalid if: client_delta+2 < server_delta - server_delta/16
|
|
||||||
// OR: server_delta+2 < client_delta - client_delta/16
|
|
||||||
// The +2 handles quantization, /16 (6.25%) handles clock drift.
|
|
||||||
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
|
|
||||||
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
|
|
||||||
serverDelta := newEpoch - c.lastEpoch
|
|
||||||
|
|
||||||
// Check for epoch going backwards or jumping unexpectedly.
|
|
||||||
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
|
|
||||||
if clientDelta+2 < serverDelta-(serverDelta/16) ||
|
|
||||||
serverDelta+2 < clientDelta-(clientDelta/16) {
|
|
||||||
stateLost = true
|
|
||||||
c.epochStateLost = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.lastEpoch = newEpoch
|
|
||||||
c.epochTime = now
|
|
||||||
return stateLost
|
|
||||||
}
|
|
||||||
|
|
||||||
// cacheExternalIP stores the external IP from a successful MAP response.
|
|
||||||
// Caller must hold c.mu.
|
|
||||||
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
|
|
||||||
if ip.IsValid() && !ip.IsUnspecified() {
|
|
||||||
c.externalIP = ip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
|
|
||||||
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
|
|
||||||
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
|
|
||||||
|
|
||||||
var lastErr error
|
|
||||||
delay := initialRetryDelay
|
|
||||||
|
|
||||||
for range maxRetries {
|
|
||||||
resp, err := c.sendOnce(ctx, addr, req)
|
|
||||||
if err == nil {
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
lastErr = err
|
|
||||||
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
|
|
||||||
// RAND is random between -0.1 and +0.1
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-time.After(retryDelayWithJitter(delay)):
|
|
||||||
}
|
|
||||||
delay = min(delay*2, maxRetryDelay)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
|
|
||||||
func retryDelayWithJitter(d time.Duration) time.Duration {
|
|
||||||
var b [1]byte
|
|
||||||
_, _ = rand.Read(b[:])
|
|
||||||
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
|
|
||||||
jitter := (float64(b[0])/255.0)*0.2 - 0.1
|
|
||||||
return time.Duration(float64(d) * (1 + jitter))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
|
|
||||||
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
|
|
||||||
conn, err := net.ListenUDP("udp", nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listen: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Debugf("close UDP connection: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
timeout := c.timeout
|
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
|
||||||
if remaining := time.Until(deadline); remaining < timeout {
|
|
||||||
timeout = remaining
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
|
||||||
return nil, fmt.Errorf("set deadline: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := conn.WriteToUDP(req, addr); err != nil {
|
|
||||||
return nil, fmt.Errorf("write: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := make([]byte, responseBufferSize)
|
|
||||||
n, from, err := conn.ReadFromUDP(resp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 6887 §8.3: Validate response came from expected PCP server.
|
|
||||||
if !from.IP.Equal(addr.IP) {
|
|
||||||
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp[:n], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) getLocalIP() (netip.Addr, error) {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
if !c.localIP.IsValid() {
|
|
||||||
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
|
|
||||||
}
|
|
||||||
return c.localIP, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func protocolNumber(protocol string) (uint8, error) {
|
|
||||||
switch protocol {
|
|
||||||
case "udp", "UDP":
|
|
||||||
return ProtoUDP, nil
|
|
||||||
case "tcp", "TCP":
|
|
||||||
return ProtoTCP, nil
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error represents a PCP error response.
|
|
||||||
type Error struct {
|
|
||||||
Code uint8
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Error) Error() string {
|
|
||||||
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
|
|
||||||
}
|
|
||||||
@@ -1,187 +0,0 @@
|
|||||||
package pcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAddrConversion(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
addr netip.Addr
|
|
||||||
}{
|
|
||||||
{"IPv4", netip.MustParseAddr("192.168.1.100")},
|
|
||||||
{"IPv4 loopback", netip.MustParseAddr("127.0.0.1")},
|
|
||||||
{"IPv6", netip.MustParseAddr("2001:db8::1")},
|
|
||||||
{"IPv6 loopback", netip.MustParseAddr("::1")},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
b16 := addrTo16(tt.addr)
|
|
||||||
|
|
||||||
recovered := addrFrom16(b16)
|
|
||||||
assert.Equal(t, tt.addr, recovered, "address should round-trip")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildAnnounceRequest(t *testing.T) {
|
|
||||||
clientIP := netip.MustParseAddr("192.168.1.100")
|
|
||||||
req := buildAnnounceRequest(clientIP)
|
|
||||||
|
|
||||||
require.Len(t, req, headerSize)
|
|
||||||
assert.Equal(t, byte(Version), req[0], "version")
|
|
||||||
assert.Equal(t, byte(OpAnnounce), req[1], "opcode")
|
|
||||||
|
|
||||||
// Check client IP is properly encoded as IPv4-mapped IPv6
|
|
||||||
assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10")
|
|
||||||
assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11")
|
|
||||||
assert.Equal(t, byte(192), req[20], "IP octet 1")
|
|
||||||
assert.Equal(t, byte(168), req[21], "IP octet 2")
|
|
||||||
assert.Equal(t, byte(1), req[22], "IP octet 3")
|
|
||||||
assert.Equal(t, byte(100), req[23], "IP octet 4")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildMapRequest(t *testing.T) {
|
|
||||||
clientIP := netip.MustParseAddr("192.168.1.100")
|
|
||||||
nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
|
|
||||||
req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600)
|
|
||||||
|
|
||||||
require.Len(t, req, mapRequestSize)
|
|
||||||
assert.Equal(t, byte(Version), req[0], "version")
|
|
||||||
assert.Equal(t, byte(OpMap), req[1], "opcode")
|
|
||||||
|
|
||||||
// Lifetime at bytes 4-7
|
|
||||||
assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime")
|
|
||||||
|
|
||||||
// Nonce at bytes 24-35
|
|
||||||
assert.Equal(t, nonce[:], req[24:36], "nonce")
|
|
||||||
|
|
||||||
// Protocol at byte 36
|
|
||||||
assert.Equal(t, byte(ProtoUDP), req[36], "protocol")
|
|
||||||
|
|
||||||
// Internal port at bytes 40-41
|
|
||||||
assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port")
|
|
||||||
|
|
||||||
// External port at bytes 42-43
|
|
||||||
assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseResponse(t *testing.T) {
|
|
||||||
// Construct a valid ANNOUNCE response
|
|
||||||
resp := make([]byte, headerSize)
|
|
||||||
resp[0] = Version
|
|
||||||
resp[1] = OpAnnounce | OpReply
|
|
||||||
// Result code = 0 (success)
|
|
||||||
// Lifetime = 0
|
|
||||||
// Epoch = 12345
|
|
||||||
resp[8] = 0
|
|
||||||
resp[9] = 0
|
|
||||||
resp[10] = 0x30
|
|
||||||
resp[11] = 0x39
|
|
||||||
|
|
||||||
parsed, err := parseResponse(resp)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint8(Version), parsed.Version)
|
|
||||||
assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode)
|
|
||||||
assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode)
|
|
||||||
assert.Equal(t, uint32(12345), parsed.Epoch)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseResponseErrors(t *testing.T) {
|
|
||||||
t.Run("too short", func(t *testing.T) {
|
|
||||||
_, err := parseResponse([]byte{1, 2, 3})
|
|
||||||
assert.Error(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("wrong version", func(t *testing.T) {
|
|
||||||
resp := make([]byte, headerSize)
|
|
||||||
resp[0] = 1 // Wrong version
|
|
||||||
resp[1] = OpReply
|
|
||||||
_, err := parseResponse(resp)
|
|
||||||
assert.Error(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("missing reply bit", func(t *testing.T) {
|
|
||||||
resp := make([]byte, headerSize)
|
|
||||||
resp[0] = Version
|
|
||||||
resp[1] = OpAnnounce // Missing OpReply bit
|
|
||||||
_, err := parseResponse(resp)
|
|
||||||
assert.Error(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResultCodeString(t *testing.T) {
|
|
||||||
assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess))
|
|
||||||
assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized))
|
|
||||||
assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch))
|
|
||||||
assert.Contains(t, ResultCodeString(255), "UNKNOWN")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProtocolNumber(t *testing.T) {
|
|
||||||
proto, err := protocolNumber("udp")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint8(ProtoUDP), proto)
|
|
||||||
|
|
||||||
proto, err = protocolNumber("tcp")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint8(ProtoTCP), proto)
|
|
||||||
|
|
||||||
proto, err = protocolNumber("UDP")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint8(ProtoUDP), proto)
|
|
||||||
|
|
||||||
_, err = protocolNumber("icmp")
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientCreation(t *testing.T) {
|
|
||||||
gateway := netip.MustParseAddr("192.168.1.1").AsSlice()
|
|
||||||
|
|
||||||
client := NewClient(gateway)
|
|
||||||
assert.Equal(t, net.IP(gateway), client.Gateway())
|
|
||||||
assert.Equal(t, defaultTimeout, client.timeout)
|
|
||||||
|
|
||||||
clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second)
|
|
||||||
assert.Equal(t, 5*time.Second, clientWithTimeout.timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNATType(t *testing.T) {
|
|
||||||
n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice())
|
|
||||||
assert.Equal(t, "PCP", n.Type())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Integration test - skipped unless PCP_TEST_GATEWAY env is set
|
|
||||||
func TestClientIntegration(t *testing.T) {
|
|
||||||
t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=<gateway-ip>")
|
|
||||||
|
|
||||||
gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway
|
|
||||||
localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP
|
|
||||||
|
|
||||||
client := NewClient(gateway)
|
|
||||||
client.SetLocalIP(localIP)
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Test ANNOUNCE
|
|
||||||
epoch, err := client.Announce(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Server epoch: %d", epoch)
|
|
||||||
|
|
||||||
// Test MAP
|
|
||||||
resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Mapping: internal=%d external=%d externalIP=%s",
|
|
||||||
resp.InternalPort, resp.ExternalPort, resp.ExternalIP)
|
|
||||||
|
|
||||||
// Cleanup
|
|
||||||
err = client.DeletePortMapping(ctx, "udp", 51820)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
@@ -1,209 +0,0 @@
|
|||||||
package pcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/libp2p/go-nat"
|
|
||||||
"github.com/libp2p/go-netroute"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ nat.NAT = (*NAT)(nil)
|
|
||||||
|
|
||||||
// NAT implements the go-nat NAT interface using PCP.
|
|
||||||
// Supports dual-stack (IPv4 and IPv6) when available.
|
|
||||||
// All methods are safe for concurrent use.
|
|
||||||
//
|
|
||||||
// TODO: IPv6 pinholes use the local IPv6 address. If the address changes
|
|
||||||
// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale
|
|
||||||
// and needs to be recreated with the new address.
|
|
||||||
type NAT struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
mu sync.RWMutex
|
|
||||||
// client6 is the IPv6 PCP client, nil if IPv6 is unavailable.
|
|
||||||
client6 *Client
|
|
||||||
// localIP6 caches the local IPv6 address used for PCP requests.
|
|
||||||
localIP6 netip.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewNAT creates a new NAT instance backed by PCP.
|
|
||||||
func NewNAT(gateway, localIP net.IP) *NAT {
|
|
||||||
client := NewClient(gateway)
|
|
||||||
client.SetLocalIP(localIP)
|
|
||||||
return &NAT{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type returns "PCP" as the NAT type.
|
|
||||||
func (n *NAT) Type() string {
|
|
||||||
return "PCP"
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceAddress returns the gateway IP address.
|
|
||||||
func (n *NAT) GetDeviceAddress() (net.IP, error) {
|
|
||||||
return n.client.Gateway(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetExternalAddress returns the external IP address.
|
|
||||||
func (n *NAT) GetExternalAddress() (net.IP, error) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
return n.client.GetExternalAddress(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetInternalAddress returns the local IP address used to communicate with the gateway.
|
|
||||||
func (n *NAT) GetInternalAddress() (net.IP, error) {
|
|
||||||
addr, err := n.client.getLocalIP()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return addr.AsSlice(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available).
|
|
||||||
func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) {
|
|
||||||
resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("add mapping: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
n.mu.RLock()
|
|
||||||
client6 := n.client6
|
|
||||||
localIP6 := n.localIP6
|
|
||||||
n.mu.RUnlock()
|
|
||||||
|
|
||||||
if client6 == nil {
|
|
||||||
return int(resp.ExternalPort), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil {
|
|
||||||
log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err)
|
|
||||||
return int(resp.ExternalPort), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort)
|
|
||||||
return int(resp.ExternalPort), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePortMapping removes a port mapping from both IPv4 and IPv6.
|
|
||||||
func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
|
||||||
err := n.client.DeletePortMapping(ctx, protocol, internalPort)
|
|
||||||
|
|
||||||
n.mu.RLock()
|
|
||||||
client6 := n.client6
|
|
||||||
n.mu.RUnlock()
|
|
||||||
|
|
||||||
if client6 != nil {
|
|
||||||
if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil {
|
|
||||||
log.Warnf("IPv6 PCP delete mapping failed: %v", err6)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("delete mapping: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive.
|
|
||||||
// Returns the current epoch and whether the server may have restarted (epoch state loss detected).
|
|
||||||
func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) {
|
|
||||||
epoch, err = n.client.Announce(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, false, fmt.Errorf("announce: %w", err)
|
|
||||||
}
|
|
||||||
return epoch, n.client.EpochStateLost(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DiscoverPCP attempts to discover a PCP-capable gateway.
|
|
||||||
// Returns a NAT interface if PCP is supported, or an error otherwise.
|
|
||||||
// Discovers both IPv4 and IPv6 gateways when available.
|
|
||||||
func DiscoverPCP(ctx context.Context) (nat.NAT, error) {
|
|
||||||
gateway, localIP, err := getDefaultGateway()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get default gateway: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClient(gateway)
|
|
||||||
client.SetLocalIP(localIP)
|
|
||||||
if _, err := client.Announce(ctx); err != nil {
|
|
||||||
return nil, fmt.Errorf("PCP announce: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := &NAT{client: client}
|
|
||||||
discoverIPv6(ctx, result)
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func discoverIPv6(ctx context.Context, result *NAT) {
|
|
||||||
gateway6, localIP6, err := getDefaultGateway6()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("IPv6 gateway discovery failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client6 := NewClient(gateway6)
|
|
||||||
client6.SetLocalIP(localIP6)
|
|
||||||
if _, err := client6.Announce(ctx); err != nil {
|
|
||||||
log.Debugf("PCP IPv6 announce failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, ok := netip.AddrFromSlice(localIP6)
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("invalid IPv6 local IP: %v", localIP6)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
result.mu.Lock()
|
|
||||||
result.client6 = client6
|
|
||||||
result.localIP6 = addr
|
|
||||||
result.mu.Unlock()
|
|
||||||
log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table.
|
|
||||||
func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
|
|
||||||
router, err := netroute.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, gateway, localIP, err = router.Route(net.IPv4zero)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if gateway == nil {
|
|
||||||
return nil, nil, nat.ErrNoNATFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return gateway, localIP, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table.
|
|
||||||
func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
|
|
||||||
router, err := netroute.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, gateway, localIP, err = router.Route(net.IPv6zero)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if gateway == nil {
|
|
||||||
return nil, nil, nat.ErrNoNATFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return gateway, localIP, nil
|
|
||||||
}
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
// Package pcp implements the Port Control Protocol (RFC 6887).
|
|
||||||
//
|
|
||||||
// # Implemented Features
|
|
||||||
//
|
|
||||||
// - ANNOUNCE opcode: Discovers PCP server support
|
|
||||||
// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6)
|
|
||||||
// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients
|
|
||||||
// - Nonce validation: Prevents response spoofing
|
|
||||||
// - Epoch tracking: Detects server restarts per Section 8.5
|
|
||||||
// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1)
|
|
||||||
//
|
|
||||||
// # Not Implemented
|
|
||||||
//
|
|
||||||
// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal)
|
|
||||||
// - THIRD_PARTY option: For managing mappings on behalf of other devices
|
|
||||||
// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing)
|
|
||||||
// - FILTER option: To restrict remote peer addresses
|
|
||||||
//
|
|
||||||
// These optional features are omitted because the primary use case is simple
|
|
||||||
// port forwarding for WireGuard, which only requires MAP with default behavior.
|
|
||||||
package pcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Version is the PCP protocol version (RFC 6887).
|
|
||||||
Version = 2
|
|
||||||
|
|
||||||
// Port is the standard PCP server port.
|
|
||||||
Port = 5351
|
|
||||||
|
|
||||||
// DefaultLifetime is the default requested mapping lifetime in seconds.
|
|
||||||
DefaultLifetime = 7200 // 2 hours
|
|
||||||
|
|
||||||
// Header sizes
|
|
||||||
headerSize = 24
|
|
||||||
mapPayloadSize = 36
|
|
||||||
mapRequestSize = headerSize + mapPayloadSize // 60 bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
// Opcodes
|
|
||||||
const (
|
|
||||||
OpAnnounce = 0
|
|
||||||
OpMap = 1
|
|
||||||
OpPeer = 2
|
|
||||||
OpReply = 0x80 // OR'd with opcode in responses
|
|
||||||
)
|
|
||||||
|
|
||||||
// Protocol numbers for MAP requests
|
|
||||||
const (
|
|
||||||
ProtoUDP = 17
|
|
||||||
ProtoTCP = 6
|
|
||||||
)
|
|
||||||
|
|
||||||
// Result codes (RFC 6887 Section 7.4)
|
|
||||||
const (
|
|
||||||
ResultSuccess = 0
|
|
||||||
ResultUnsuppVersion = 1
|
|
||||||
ResultNotAuthorized = 2
|
|
||||||
ResultMalformedRequest = 3
|
|
||||||
ResultUnsuppOpcode = 4
|
|
||||||
ResultUnsuppOption = 5
|
|
||||||
ResultMalformedOption = 6
|
|
||||||
ResultNetworkFailure = 7
|
|
||||||
ResultNoResources = 8
|
|
||||||
ResultUnsuppProtocol = 9
|
|
||||||
ResultUserExQuota = 10
|
|
||||||
ResultCannotProvideExt = 11
|
|
||||||
ResultAddressMismatch = 12
|
|
||||||
ResultExcessiveRemotePeers = 13
|
|
||||||
)
|
|
||||||
|
|
||||||
// ResultCodeString returns a human-readable string for a result code.
|
|
||||||
func ResultCodeString(code uint8) string {
|
|
||||||
switch code {
|
|
||||||
case ResultSuccess:
|
|
||||||
return "SUCCESS"
|
|
||||||
case ResultUnsuppVersion:
|
|
||||||
return "UNSUPP_VERSION"
|
|
||||||
case ResultNotAuthorized:
|
|
||||||
return "NOT_AUTHORIZED"
|
|
||||||
case ResultMalformedRequest:
|
|
||||||
return "MALFORMED_REQUEST"
|
|
||||||
case ResultUnsuppOpcode:
|
|
||||||
return "UNSUPP_OPCODE"
|
|
||||||
case ResultUnsuppOption:
|
|
||||||
return "UNSUPP_OPTION"
|
|
||||||
case ResultMalformedOption:
|
|
||||||
return "MALFORMED_OPTION"
|
|
||||||
case ResultNetworkFailure:
|
|
||||||
return "NETWORK_FAILURE"
|
|
||||||
case ResultNoResources:
|
|
||||||
return "NO_RESOURCES"
|
|
||||||
case ResultUnsuppProtocol:
|
|
||||||
return "UNSUPP_PROTOCOL"
|
|
||||||
case ResultUserExQuota:
|
|
||||||
return "USER_EX_QUOTA"
|
|
||||||
case ResultCannotProvideExt:
|
|
||||||
return "CANNOT_PROVIDE_EXTERNAL"
|
|
||||||
case ResultAddressMismatch:
|
|
||||||
return "ADDRESS_MISMATCH"
|
|
||||||
case ResultExcessiveRemotePeers:
|
|
||||||
return "EXCESSIVE_REMOTE_PEERS"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("UNKNOWN(%d)", code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Response represents a parsed PCP response header.
|
|
||||||
type Response struct {
|
|
||||||
Version uint8
|
|
||||||
Opcode uint8
|
|
||||||
ResultCode uint8
|
|
||||||
Lifetime uint32
|
|
||||||
Epoch uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// MapResponse contains the full response to a MAP request.
|
|
||||||
type MapResponse struct {
|
|
||||||
Response
|
|
||||||
Nonce [12]byte
|
|
||||||
Protocol uint8
|
|
||||||
InternalPort uint16
|
|
||||||
ExternalPort uint16
|
|
||||||
ExternalIP netip.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation.
|
|
||||||
func addrTo16(addr netip.Addr) [16]byte {
|
|
||||||
if addr.Is4() {
|
|
||||||
return netip.AddrFrom4(addr.As4()).As16()
|
|
||||||
}
|
|
||||||
return addr.As16()
|
|
||||||
}
|
|
||||||
|
|
||||||
// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4.
|
|
||||||
func addrFrom16(b [16]byte) netip.Addr {
|
|
||||||
return netip.AddrFrom16(b).Unmap()
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildAnnounceRequest creates a PCP ANNOUNCE request packet.
|
|
||||||
func buildAnnounceRequest(clientIP netip.Addr) []byte {
|
|
||||||
req := make([]byte, headerSize)
|
|
||||||
req[0] = Version
|
|
||||||
req[1] = OpAnnounce
|
|
||||||
mapped := addrTo16(clientIP)
|
|
||||||
copy(req[8:24], mapped[:])
|
|
||||||
return req
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildMapRequest creates a PCP MAP request packet.
|
|
||||||
func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte {
|
|
||||||
req := make([]byte, mapRequestSize)
|
|
||||||
|
|
||||||
// Header
|
|
||||||
req[0] = Version
|
|
||||||
req[1] = OpMap
|
|
||||||
binary.BigEndian.PutUint32(req[4:8], lifetime)
|
|
||||||
mapped := addrTo16(clientIP)
|
|
||||||
copy(req[8:24], mapped[:])
|
|
||||||
|
|
||||||
// MAP payload
|
|
||||||
copy(req[24:36], nonce[:])
|
|
||||||
req[36] = protocol
|
|
||||||
binary.BigEndian.PutUint16(req[40:42], internalPort)
|
|
||||||
binary.BigEndian.PutUint16(req[42:44], suggestedExtPort)
|
|
||||||
if suggestedExtIP.IsValid() {
|
|
||||||
extMapped := addrTo16(suggestedExtIP)
|
|
||||||
copy(req[44:60], extMapped[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
return req
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseResponse parses the common PCP response header.
|
|
||||||
func parseResponse(data []byte) (*Response, error) {
|
|
||||||
if len(data) < headerSize {
|
|
||||||
return nil, fmt.Errorf("response too short: %d bytes", len(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &Response{
|
|
||||||
Version: data[0],
|
|
||||||
Opcode: data[1],
|
|
||||||
ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2)
|
|
||||||
Lifetime: binary.BigEndian.Uint32(data[4:8]),
|
|
||||||
Epoch: binary.BigEndian.Uint32(data[8:12]),
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Version != Version {
|
|
||||||
return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Opcode&OpReply == 0 {
|
|
||||||
return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseMapResponse parses a complete MAP response.
|
|
||||||
func parseMapResponse(data []byte) (*MapResponse, error) {
|
|
||||||
if len(data) < mapRequestSize {
|
|
||||||
return nil, fmt.Errorf("MAP response too short: %d bytes", len(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := parseResponse(data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse header: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mapResp := &MapResponse{
|
|
||||||
Response: *resp,
|
|
||||||
Protocol: data[36],
|
|
||||||
InternalPort: binary.BigEndian.Uint16(data[40:42]),
|
|
||||||
ExternalPort: binary.BigEndian.Uint16(data[42:44]),
|
|
||||||
ExternalIP: addrFrom16([16]byte(data[44:60])),
|
|
||||||
}
|
|
||||||
copy(mapResp.Nonce[:], data[24:36])
|
|
||||||
|
|
||||||
return mapResp, nil
|
|
||||||
}
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
//go:build !js
|
|
||||||
|
|
||||||
package portforward
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/libp2p/go-nat"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// discoverGateway is the function used for NAT gateway discovery.
|
|
||||||
// It can be replaced in tests to avoid real network operations.
|
|
||||||
// Tries PCP first, then falls back to NAT-PMP/UPnP.
|
|
||||||
var discoverGateway = defaultDiscoverGateway
|
|
||||||
|
|
||||||
func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) {
|
|
||||||
pcpGateway, err := pcp.DiscoverPCP(ctx)
|
|
||||||
if err == nil {
|
|
||||||
return pcpGateway, nil
|
|
||||||
}
|
|
||||||
log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err)
|
|
||||||
|
|
||||||
return nat.DiscoverGateway(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// State is persisted only for crash recovery cleanup
|
|
||||||
type State struct {
|
|
||||||
InternalPort uint16 `json:"internal_port,omitempty"`
|
|
||||||
Protocol string `json:"protocol,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) Name() string {
|
|
||||||
return "port_forward_state"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup implements statemanager.CleanableState for crash recovery
|
|
||||||
func (s *State) Cleanup() error {
|
|
||||||
if s.InternalPort == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("cleaning up stale port mapping for port %d", s.InternalPort)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
gateway, err := discoverGateway(ctx)
|
|
||||||
if err != nil {
|
|
||||||
// Discovery failure is not an error - gateway may not exist
|
|
||||||
log.Debugf("cleanup: no gateway found: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil {
|
|
||||||
return fmt.Errorf("delete port mapping: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -39,18 +39,6 @@ const (
|
|||||||
DefaultAdminURL = "https://app.netbird.io:443"
|
DefaultAdminURL = "https://app.netbird.io:443"
|
||||||
)
|
)
|
||||||
|
|
||||||
// mgmProber is the subset of management client needed for URL migration probes.
|
|
||||||
type mgmProber interface {
|
|
||||||
HealthCheck() error
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// newMgmProber creates a management client for probing URL reachability.
|
|
||||||
// Overridden in tests to avoid real network calls.
|
|
||||||
var newMgmProber = func(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (mgmProber, error) {
|
|
||||||
return mgm.NewClient(ctx, addr, key, tlsEnabled)
|
|
||||||
}
|
|
||||||
|
|
||||||
var DefaultInterfaceBlacklist = []string{
|
var DefaultInterfaceBlacklist = []string{
|
||||||
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||||
@@ -765,19 +753,21 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
|||||||
return config, err
|
return config, err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := newMgmProber(ctx, newURL.Host, key, mgmTlsEnabled)
|
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||||
return config, err
|
return config, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := client.Close(); err != nil {
|
err = client.Close()
|
||||||
|
if err != nil {
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
log.Warnf("failed to close the Management service client %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// gRPC check
|
// gRPC check
|
||||||
if err = client.HealthCheck(); err != nil {
|
_, err = client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,21 +10,12 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockMgmProber struct{}
|
|
||||||
|
|
||||||
func (m *mockMgmProber) HealthCheck() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockMgmProber) Close() error { return nil }
|
|
||||||
|
|
||||||
func TestGetConfig(t *testing.T) {
|
func TestGetConfig(t *testing.T) {
|
||||||
// case 1: new default config has to be generated
|
// case 1: new default config has to be generated
|
||||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||||
@@ -243,12 +234,6 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateOldManagementURL(t *testing.T) {
|
func TestUpdateOldManagementURL(t *testing.T) {
|
||||||
origProber := newMgmProber
|
|
||||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
|
||||||
return &mockMgmProber{}, nil
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { newMgmProber = origProber })
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
previousManagementURL string
|
previousManagementURL string
|
||||||
@@ -288,17 +273,18 @@ func TestUpdateOldManagementURL(t *testing.T) {
|
|||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
})
|
})
|
||||||
require.NoError(t, err, "failed to create testing config")
|
require.NoError(t, err, "failed to create testing config")
|
||||||
previousContent, err := os.ReadFile(configPath)
|
previousStats, err := os.Stat(configPath)
|
||||||
require.NoError(t, err, "failed to read initial config")
|
require.NoError(t, err, "failed to create testing config stats")
|
||||||
resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath)
|
resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath)
|
||||||
require.NoError(t, err, "got error when updating old management url")
|
require.NoError(t, err, "got error when updating old management url")
|
||||||
require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String())
|
require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String())
|
||||||
newContent, err := os.ReadFile(configPath)
|
newStats, err := os.Stat(configPath)
|
||||||
require.NoError(t, err, "failed to read updated config")
|
require.NoError(t, err, "failed to create testing config stats")
|
||||||
if tt.fileShouldNotChange {
|
switch tt.fileShouldNotChange {
|
||||||
require.Equal(t, string(previousContent), string(newContent), "file should not change")
|
case true:
|
||||||
} else {
|
require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change")
|
||||||
require.NotEqual(t, string(previousContent), string(newContent), "file should have changed")
|
case false:
|
||||||
|
require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,6 @@ type Manager interface {
|
|||||||
TriggerSelection(route.HAMap)
|
TriggerSelection(route.HAMap)
|
||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
GetClientRoutes() route.HAMap
|
GetClientRoutes() route.HAMap
|
||||||
GetSelectedClientRoutes() route.HAMap
|
|
||||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
@@ -168,7 +167,6 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
}
|
}
|
||||||
cr = append(cr, fakeIPRoute)
|
cr = append(cr, fakeIPRoute)
|
||||||
m.notifier.SetFakeIPRoute(fakeIPRoute)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||||
@@ -467,16 +465,6 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap {
|
|||||||
return maps.Clone(m.clientRoutes)
|
return maps.Clone(m.clientRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSelectedClientRoutes returns only the currently selected/active client routes,
|
|
||||||
// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking
|
|
||||||
// if traffic should be routed through the tunnel.
|
|
||||||
func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
|
|
||||||
m.mux.Lock()
|
|
||||||
defer m.mux.Unlock()
|
|
||||||
|
|
||||||
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||||
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ type MockManager struct {
|
|||||||
TriggerSelectionFunc func(haMap route.HAMap)
|
TriggerSelectionFunc func(haMap route.HAMap)
|
||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
GetClientRoutesFunc func() route.HAMap
|
GetClientRoutesFunc func() route.HAMap
|
||||||
GetSelectedClientRoutesFunc func() route.HAMap
|
|
||||||
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||||
StopFunc func(manager *statemanager.Manager)
|
StopFunc func(manager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
@@ -62,7 +61,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface
|
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
|
||||||
func (m *MockManager) GetClientRoutes() route.HAMap {
|
func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||||
if m.GetClientRoutesFunc != nil {
|
if m.GetClientRoutesFunc != nil {
|
||||||
return m.GetClientRoutesFunc()
|
return m.GetClientRoutesFunc()
|
||||||
@@ -70,14 +69,6 @@ func (m *MockManager) GetClientRoutes() route.HAMap {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface
|
|
||||||
func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
|
|
||||||
if m.GetSelectedClientRoutesFunc != nil {
|
|
||||||
return m.GetSelectedClientRoutesFunc()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||||
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
if m.GetClientRoutesWithNetIDFunc != nil {
|
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
type Notifier struct {
|
type Notifier struct {
|
||||||
initialRoutes []*route.Route
|
initialRoutes []*route.Route
|
||||||
currentRoutes []*route.Route
|
currentRoutes []*route.Route
|
||||||
fakeIPRoute *route.Route
|
|
||||||
|
|
||||||
listener listener.NetworkChangeListener
|
listener listener.NetworkChangeListener
|
||||||
listenerMux sync.Mutex
|
listenerMux sync.Mutex
|
||||||
@@ -32,15 +31,26 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
|||||||
n.listener = listener
|
n.listener = listener
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
|
|
||||||
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||||
n.initialRoutes = filterStatic(initialRoutes)
|
// initialRoutes contains fake IP block for interface configuration
|
||||||
n.currentRoutes = filterStatic(routesForComparison)
|
filteredInitial := make([]*route.Route, 0)
|
||||||
}
|
for _, r := range initialRoutes {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredInitial = append(filteredInitial, r)
|
||||||
|
}
|
||||||
|
n.initialRoutes = filteredInitial
|
||||||
|
|
||||||
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
|
// routesForComparison excludes fake IP block for comparison with new routes
|
||||||
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
|
filteredComparison := make([]*route.Route, 0)
|
||||||
n.fakeIPRoute = r
|
for _, r := range routesForComparison {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredComparison = append(filteredComparison, r)
|
||||||
|
}
|
||||||
|
n.currentRoutes = filteredComparison
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
@@ -73,28 +83,13 @@ func (n *Notifier) notify() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
allRoutes := slices.Clone(n.currentRoutes)
|
routeStrings := n.routesToStrings(n.currentRoutes)
|
||||||
if n.fakeIPRoute != nil {
|
|
||||||
allRoutes = append(allRoutes, n.fakeIPRoute)
|
|
||||||
}
|
|
||||||
|
|
||||||
routeStrings := n.routesToStrings(allRoutes)
|
|
||||||
sort.Strings(routeStrings)
|
sort.Strings(routeStrings)
|
||||||
go func(l listener.NetworkChangeListener) {
|
go func(l listener.NetworkChangeListener) {
|
||||||
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, allRoutes), ","))
|
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
|
||||||
}(n.listener)
|
}(n.listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterStatic(routes []*route.Route) []*route.Route {
|
|
||||||
out := make([]*route.Route, 0, len(routes))
|
|
||||||
for _, r := range routes {
|
|
||||||
if !r.IsDynamic() {
|
|
||||||
out = append(out, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
|
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
|
||||||
nets := make([]string, 0, len(routes))
|
nets := make([]string, 0, len(routes))
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
|||||||
@@ -34,10 +34,6 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
|||||||
// iOS doesn't care about initial routes
|
// iOS doesn't care about initial routes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
|
||||||
// Not used on iOS
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||||
// Not used on iOS
|
// Not used on iOS
|
||||||
}
|
}
|
||||||
@@ -57,6 +53,7 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
|||||||
n.currentPrefixes = newNets
|
n.currentPrefixes = newNets
|
||||||
n.notify()
|
n.notify()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) notify() {
|
func (n *Notifier) notify() {
|
||||||
n.listenerMux.Lock()
|
n.listenerMux.Lock()
|
||||||
defer n.listenerMux.Unlock()
|
defer n.listenerMux.Unlock()
|
||||||
|
|||||||
@@ -23,10 +23,6 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
|||||||
// Not used on non-mobile platforms
|
// Not used on non-mobile platforms
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
|
||||||
// Not used on non-mobile platforms
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
// Not used on non-mobile platforms
|
// Not used on non-mobile platforms
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
|
|
||||||
|
|
||||||
package systemops
|
|
||||||
|
|
||||||
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
|
|
||||||
// always fall through to the ref-counter exclusion-route path; these stubs
|
|
||||||
// exist only so systemops_unix.go compiles.
|
|
||||||
func (r *SysOps) setupAdvancedRouting() error { return nil }
|
|
||||||
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
|
|
||||||
func (r *SysOps) flushPlatformExtras() error { return nil }
|
|
||||||
@@ -1,241 +0,0 @@
|
|||||||
//go:build darwin && !ios
|
|
||||||
|
|
||||||
package systemops
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/net/route"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// scopedRouteBudget bounds retries for the scoped default route. Installing or
|
|
||||||
// deleting it matters enough that we're willing to spend longer waiting for the
|
|
||||||
// kernel reply than for per-prefix exclusion routes.
|
|
||||||
const scopedRouteBudget = 5 * time.Second
|
|
||||||
|
|
||||||
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
|
|
||||||
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
|
|
||||||
// resolve gateway'd destinations while the VPN's split default owns the
|
|
||||||
// unscoped table.
|
|
||||||
//
|
|
||||||
// Timing note: this runs during routeManager.Init, which happens before the
|
|
||||||
// VPN interface is created and before any peer routes propagate. The initial
|
|
||||||
// mgmt / signal / relay TCP dials always fire before this runs, so those
|
|
||||||
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
|
|
||||||
// lookup, which at that point correctly picks the physical default. Those
|
|
||||||
// already-established TCP flows keep their originally-selected interface for
|
|
||||||
// their lifetime on Darwin because the kernel caches the egress route
|
|
||||||
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
|
|
||||||
// afterwards does not migrate them since the original en0 default stays in
|
|
||||||
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
|
|
||||||
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
|
|
||||||
func (r *SysOps) setupAdvancedRouting() error {
|
|
||||||
// Drop any previously-cached egress interface before reinstalling. On a
|
|
||||||
// refresh, a family that no longer resolves would otherwise keep the stale
|
|
||||||
// binding, causing new sockets to scope to an interface without a matching
|
|
||||||
// scoped default.
|
|
||||||
nbnet.ClearBoundInterfaces()
|
|
||||||
|
|
||||||
if err := r.flushScopedDefaults(); err != nil {
|
|
||||||
log.Warnf("flush residual scoped defaults: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
installed := 0
|
|
||||||
|
|
||||||
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
|
|
||||||
ok, err := r.installScopedDefaultFor(unspec)
|
|
||||||
if err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
installed++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if installed == 0 && merr != nil {
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
if merr != nil {
|
|
||||||
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// installScopedDefaultFor resolves the physical default nexthop for the given
|
|
||||||
// address family, installs a scoped default via it, and caches the iface for
|
|
||||||
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
|
|
||||||
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
|
||||||
nexthop, err := GetNextHop(unspec)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, vars.ErrRouteNotFound) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
|
|
||||||
}
|
|
||||||
if nexthop.Intf == nil {
|
|
||||||
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
|
||||||
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
af := unix.AF_INET
|
|
||||||
if unspec.Is6() {
|
|
||||||
af = unix.AF_INET6
|
|
||||||
}
|
|
||||||
nbnet.SetBoundInterface(af, nexthop.Intf)
|
|
||||||
via := "point-to-point"
|
|
||||||
if nexthop.IP.IsValid() {
|
|
||||||
via = nexthop.IP.String()
|
|
||||||
}
|
|
||||||
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SysOps) cleanupAdvancedRouting() error {
|
|
||||||
nbnet.ClearBoundInterfaces()
|
|
||||||
return r.flushScopedDefaults()
|
|
||||||
}
|
|
||||||
|
|
||||||
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
|
|
||||||
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
|
|
||||||
// removed on the next boot regardless of whether a profile is brought up.
|
|
||||||
func (r *SysOps) flushPlatformExtras() error {
|
|
||||||
return r.flushScopedDefaults()
|
|
||||||
}
|
|
||||||
|
|
||||||
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
|
|
||||||
// Safe to call at startup to clear residual entries from a prior session.
|
|
||||||
func (r *SysOps) flushScopedDefaults() error {
|
|
||||||
rib, err := retryFetchRIB()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("fetch routing table: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse routing table: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
removed := 0
|
|
||||||
|
|
||||||
for _, msg := range msgs {
|
|
||||||
rtMsg, ok := msg.(*route.RouteMessage)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if rtMsg.Flags&routeProtoFlag == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
info, err := MsgToRoute(rtMsg)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("skip scoped flush: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.deleteScopedRoute(rtMsg); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
|
|
||||||
info.Dst, rtMsg.Index, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
removed++
|
|
||||||
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
|
|
||||||
}
|
|
||||||
|
|
||||||
if removed > 0 {
|
|
||||||
log.Infof("flushed %d residual scoped default route(s)", removed)
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
|
|
||||||
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
|
|
||||||
// Preserve identifying flags from the stored route (including RTF_GATEWAY
|
|
||||||
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
|
|
||||||
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
|
|
||||||
del := &route.RouteMessage{
|
|
||||||
Type: unix.RTM_DELETE,
|
|
||||||
Flags: rtMsg.Flags & keep,
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Seq: r.getSeq(),
|
|
||||||
Index: rtMsg.Index,
|
|
||||||
Addrs: rtMsg.Addrs,
|
|
||||||
}
|
|
||||||
return r.writeRouteMessage(del, scopedRouteBudget)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
|
|
||||||
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
|
|
||||||
|
|
||||||
msg := &route.RouteMessage{
|
|
||||||
Type: action,
|
|
||||||
Flags: flags,
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
ID: uintptr(os.Getpid()),
|
|
||||||
Seq: r.getSeq(),
|
|
||||||
Index: nexthop.Intf.Index,
|
|
||||||
}
|
|
||||||
|
|
||||||
const numAddrs = unix.RTAX_NETMASK + 1
|
|
||||||
addrs := make([]route.Addr, numAddrs)
|
|
||||||
|
|
||||||
dst, err := addrToRouteAddr(unspec)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("build destination: %w", err)
|
|
||||||
}
|
|
||||||
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("build netmask: %w", err)
|
|
||||||
}
|
|
||||||
addrs[unix.RTAX_DST] = dst
|
|
||||||
addrs[unix.RTAX_NETMASK] = mask
|
|
||||||
|
|
||||||
if nexthop.IP.IsValid() {
|
|
||||||
msg.Flags |= unix.RTF_GATEWAY
|
|
||||||
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("build gateway: %w", err)
|
|
||||||
}
|
|
||||||
addrs[unix.RTAX_GATEWAY] = gw
|
|
||||||
} else {
|
|
||||||
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
|
|
||||||
Index: nexthop.Intf.Index,
|
|
||||||
Name: nexthop.Intf.Name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg.Addrs = addrs
|
|
||||||
|
|
||||||
return r.writeRouteMessage(msg, scopedRouteBudget)
|
|
||||||
}
|
|
||||||
|
|
||||||
func afOf(a netip.Addr) string {
|
|
||||||
if a.Is4() {
|
|
||||||
return "IPv4"
|
|
||||||
}
|
|
||||||
return "IPv6"
|
|
||||||
}
|
|
||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
"github.com/netbirdio/netbird/client/net/hooks"
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,6 +31,8 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
|||||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||||
|
|
||||||
|
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||||
|
|
||||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
@@ -396,16 +397,12 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||||
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
|
|
||||||
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
|
|
||||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||||
if nbnet.AdvancedRouting() {
|
localRoutes, err := hasSeparateRouting()
|
||||||
return false, netip.Prefix{}
|
|
||||||
}
|
|
||||||
|
|
||||||
localRoutes, err := GetRoutesFromTable()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get routes: %v", err)
|
if !errors.Is(err, ErrRoutingIsSeparate) {
|
||||||
|
log.Errorf("Failed to get routes: %v", err)
|
||||||
|
}
|
||||||
return false, netip.Prefix{}
|
return false, netip.Prefix{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
return []netip.Prefix{}, nil
|
return []netip.Prefix{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
|
return []netip.Prefix{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||||
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||||
return []DetailedRoute{}, nil
|
return []DetailedRoute{}, nil
|
||||||
|
|||||||
@@ -894,6 +894,13 @@ func getAddressFamily(prefix netip.Prefix) int {
|
|||||||
return netlink.FAMILY_V6
|
return netlink.FAMILY_V6
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
|
if !nbnet.AdvancedRouting() {
|
||||||
|
return GetRoutesFromTable()
|
||||||
|
}
|
||||||
|
return nil, ErrRoutingIsSeparate
|
||||||
|
}
|
||||||
|
|
||||||
func isOpErr(err error) bool {
|
func isOpErr(err error) bool {
|
||||||
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
|
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
|
||||||
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
|
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
|
||||||
|
|||||||
@@ -48,6 +48,10 @@ func EnableIPForwarding() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
|
return GetRoutesFromTable()
|
||||||
|
}
|
||||||
|
|
||||||
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
|
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
|
||||||
func GetIPRules() ([]IPRule, error) {
|
func GetIPRules() ([]IPRule, error) {
|
||||||
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)
|
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user