mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-10 18:09:56 +00:00
Compare commits
17 Commits
fix/ios-de
...
tests/enab
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a6b25b9af | ||
|
|
a40028092d | ||
|
|
13200265d8 | ||
|
|
ed7a9363aa | ||
|
|
d56859dc5d | ||
|
|
367d37050b | ||
|
|
106527182f | ||
|
|
8e1d5b78c2 | ||
|
|
d3b63c6be9 | ||
|
|
60d2fa08b0 | ||
|
|
1e7b16db0a | ||
|
|
b377d99933 | ||
|
|
512899d82d | ||
|
|
5993ec6e43 | ||
|
|
eac6d501c3 | ||
|
|
deeae30612 | ||
|
|
f3cdf163e1 |
9
.github/workflows/golang-test-darwin.yml
vendored
9
.github/workflows/golang-test-darwin.yml
vendored
@@ -45,4 +45,11 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,client
|
||||
|
||||
61
.github/workflows/golang-test-linux.yml
vendored
61
.github/workflows/golang-test-linux.yml
vendored
@@ -158,7 +158,16 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -race -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,client
|
||||
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -276,9 +285,17 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test ${{ matrix.raceFlag }} \
|
||||
-exec 'sudo' \
|
||||
-exec 'sudo' -coverprofile=coverage.txt \
|
||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,relay
|
||||
|
||||
test_proxy:
|
||||
name: "Proxy / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -326,7 +343,15 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test -timeout 10m -p 1 ./proxy/...
|
||||
go test -timeout 10m -p 1 -coverprofile=coverage.txt ./proxy/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,proxy
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
@@ -377,9 +402,17 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
-exec 'sudo' \
|
||||
-exec 'sudo' -coverprofile=coverage.txt \
|
||||
-timeout 10m ./signal/... ./shared/signal/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,signal
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -445,10 +478,18 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=devcert \
|
||||
go test -race -tags=devcert -coverprofile=coverage.txt \
|
||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||
-timeout 20m ./management/... ./shared/management/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,management
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [build-cache]
|
||||
@@ -687,6 +728,14 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=integration \
|
||||
go test -tags=integration -coverprofile=coverage.txt \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/server/http/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: integration,management
|
||||
|
||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -29,10 +29,10 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
run: bash -x release_files/freebsd-port-diff.sh
|
||||
|
||||
- name: Generate FreeBSD port issue body
|
||||
run: bash release_files/freebsd-port-issue-body.sh
|
||||
run: bash -x release_files/freebsd-port-issue-body.sh
|
||||
|
||||
- name: Check if diff was generated
|
||||
id: check_diff
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -65,7 +65,7 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
if [ ${SIZE} -gt 62914560 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
@@ -100,6 +101,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
Anonymize: anonymizeFlag,
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
}
|
||||
if uploadBundleFlag {
|
||||
request.UploadURL = uploadBundleURLFlag
|
||||
@@ -298,6 +300,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
Anonymize: anonymizeFlag,
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
}
|
||||
if uploadBundleFlag {
|
||||
request.UploadURL = uploadBundleURLFlag
|
||||
@@ -432,6 +435,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: logFilePath,
|
||||
CPUProfile: nil,
|
||||
DaemonVersion: version.NetbirdVersion(), // acting as daemon
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
|
||||
@@ -102,7 +102,7 @@ func (p *program) Stop(srv service.Service) error {
|
||||
}
|
||||
|
||||
// Common setup for service control commands
|
||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc, consoleLog bool) (service.Service, error) {
|
||||
// rootCmd env vars are already applied by PersistentPreRunE.
|
||||
SetFlagsFromEnvVars(serviceCmd)
|
||||
|
||||
@@ -112,8 +112,14 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := util.InitLog(logLevel, logFiles...); err != nil {
|
||||
return nil, fmt.Errorf("init log: %w", err)
|
||||
if consoleLog {
|
||||
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||
return nil, fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := util.InitLog(logLevel, logFiles...); err != nil {
|
||||
return nil, fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
@@ -138,7 +144,7 @@ var runCmd = &cobra.Command{
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
|
||||
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -152,7 +158,7 @@ var startCmd = &cobra.Command{
|
||||
Short: "starts NetBird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -170,7 +176,7 @@ var stopCmd = &cobra.Command{
|
||||
Short: "stops NetBird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -188,7 +194,7 @@ var restartCmd = &cobra.Command{
|
||||
Short: "restarts NetBird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -206,7 +212,7 @@ var svcStatusCmd = &cobra.Command{
|
||||
Short: "shows NetBird service status",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package iptables
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
@@ -421,12 +422,17 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the maps so the persisted state holds a private snapshot. The
|
||||
// live maps keep being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing them by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write.
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
currentState.ACLEntries = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -749,11 +750,17 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the rule map so the persisted state holds a private snapshot. The
|
||||
// live map keeps being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing it by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write. The ipset counter guards itself
|
||||
// during marshaling, so it can be shared directly.
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteRules = maps.Clone(r.rules)
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package iptables
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
)
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
@@ -19,6 +22,14 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipList with its own ips map.
|
||||
func (s *ipList) clone() *ipList {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ipList{ips: maps.Clone(s.ips)}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
@@ -55,6 +66,19 @@ func newIpsetStore() *ipsetStore {
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
||||
// independent ipList entries.
|
||||
func (s *ipsetStore) clone() *ipsetStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
||||
for name, list := range s.ipsets {
|
||||
cloned.ipsets[name] = list.clone()
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
|
||||
@@ -118,7 +118,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
stateFilePath string,
|
||||
cacheDir string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
debug.SetGCPercent(5)
|
||||
@@ -128,7 +127,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
@@ -250,11 +250,12 @@ type BundleGenerator struct {
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logPath string
|
||||
tempDir string
|
||||
statePath string
|
||||
cpuProfile []byte
|
||||
capturePath string
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
clientMetrics MetricsExporter
|
||||
daemonVersion string
|
||||
cliVersion string
|
||||
|
||||
anonymize bool
|
||||
includeSystemInfo bool
|
||||
@@ -275,11 +276,12 @@ type GeneratorDependencies struct {
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogPath string
|
||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||
StatePath string // Path to the state file. If empty, the ServiceManager default path is used.
|
||||
CPUProfile []byte
|
||||
CapturePath string
|
||||
RefreshStatus func()
|
||||
ClientMetrics MetricsExporter
|
||||
DaemonVersion string
|
||||
CliVersion string
|
||||
}
|
||||
|
||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||
@@ -297,11 +299,12 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
syncResponse: deps.SyncResponse,
|
||||
logPath: deps.LogPath,
|
||||
tempDir: deps.TempDir,
|
||||
statePath: deps.StatePath,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
capturePath: deps.CapturePath,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
clientMetrics: deps.ClientMetrics,
|
||||
daemonVersion: deps.DaemonVersion,
|
||||
cliVersion: deps.CliVersion,
|
||||
|
||||
anonymize: cfg.Anonymize,
|
||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||
@@ -462,9 +465,11 @@ func (g *BundleGenerator) addStatus() error {
|
||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
|
||||
Anonymize: g.anonymize,
|
||||
ProfileName: profName,
|
||||
Anonymize: g.anonymize,
|
||||
ProfileName: profName,
|
||||
DaemonVersion: g.daemonVersion,
|
||||
})
|
||||
overview.CliVersion = g.cliVersion
|
||||
statusOutput := overview.FullDetailSummary()
|
||||
|
||||
statusReader := strings.NewReader(statusOutput)
|
||||
@@ -801,6 +806,8 @@ func (g *BundleGenerator) addSyncResponse() error {
|
||||
AllowPartial: true,
|
||||
}
|
||||
|
||||
g.maskSecrets()
|
||||
|
||||
jsonBytes, err := options.Marshal(g.syncResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate json: %w", err)
|
||||
@@ -813,12 +820,30 @@ func (g *BundleGenerator) addSyncResponse() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStateFile() error {
|
||||
path := g.statePath
|
||||
if path == "" {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path = sm.GetStatePath()
|
||||
func (g *BundleGenerator) maskSecrets() {
|
||||
if g.syncResponse == nil || g.syncResponse.NetbirdConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if g.syncResponse.NetbirdConfig.Flow != nil {
|
||||
g.syncResponse.NetbirdConfig.Flow.TokenPayload = maskedValue
|
||||
|
||||
}
|
||||
|
||||
if g.syncResponse.NetbirdConfig.Relay != nil {
|
||||
g.syncResponse.NetbirdConfig.Relay.TokenPayload = maskedValue
|
||||
}
|
||||
|
||||
for i := range g.syncResponse.NetbirdConfig.Turns {
|
||||
if g.syncResponse.NetbirdConfig.Turns[i] != nil {
|
||||
g.syncResponse.NetbirdConfig.Turns[i].Password = maskedValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStateFile() error {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path := sm.GetStatePath()
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -1045,7 +1070,8 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
|
||||
return
|
||||
}
|
||||
|
||||
pattern := filepath.Join(logDir, "client-*.log.gz")
|
||||
// This glob pattern matches logs rotated both by us and by logrotate on Linux
|
||||
pattern := filepath.Join(logDir, "client*.log.*")
|
||||
files, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
log.Warnf("failed to glob rotated logs: %v", err)
|
||||
@@ -1078,7 +1104,12 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
|
||||
|
||||
for i := 0; i < maxFiles; i++ {
|
||||
name := filepath.Base(files[i])
|
||||
if err := g.addSingleLogFileGz(files[i], name); err != nil {
|
||||
if strings.HasSuffix(name, ".gz") {
|
||||
err = g.addSingleLogFileGz(files[i], name)
|
||||
} else {
|
||||
err = g.addSingleLogfile(files[i], name)
|
||||
}
|
||||
if err != nil {
|
||||
log.Warnf("failed to add rotated log %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// swiftLogFile is the Swift app log written by the iOS app into the same log
|
||||
// directory as the Go client log, so it can be collected into the bundle.
|
||||
const swiftLogFile = "swift-log.log"
|
||||
|
||||
// addPlatformLog collects logs for the iOS debug bundle. iOS has no logcat or
|
||||
// systemd journal, so we rely on file-based logs. addLogfile handles the Go
|
||||
// client log (logPath) with rotation, the stderr/stdout companions and
|
||||
// anonymization. The iOS app writes its own Swift log into the same directory,
|
||||
// so we add it alongside the Go log.
|
||||
func (g *BundleGenerator) addPlatformLog() error {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if g.logPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
swiftLogPath := filepath.Join(filepath.Dir(g.logPath), swiftLogFile)
|
||||
if err := g.addSingleLogfile(swiftLogPath, swiftLogFile); err != nil {
|
||||
// The Swift log is best-effort: the app may not have written it yet.
|
||||
log.Warnf("failed to add %s to debug bundle: %v", swiftLogFile, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
103
client/internal/debug/debug_logfiles_test.go
Normal file
103
client/internal/debug/debug_logfiles_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestAddRotatedLogFiles_PicksUpAllVariants asserts that the rotated-log
|
||||
// glob picks up logs rotated by timberjack (gzipped) and by logrotate (plain
|
||||
// and gzipped), and skips unrelated files.
|
||||
func TestAddRotatedLogFiles_PicksUpAllVariants(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeFile(t, filepath.Join(dir, "client.log"), "active log\n")
|
||||
writeFile(t, filepath.Join(dir, "other.log"), "unrelated\n")
|
||||
|
||||
timberjackRotated := "client-2026-05-21T10-30-45.000.log.gz"
|
||||
writeGzFile(t, filepath.Join(dir, timberjackRotated), "timberjack rotated content\n")
|
||||
|
||||
logrotatePlain := "client.log.1"
|
||||
writeFile(t, filepath.Join(dir, logrotatePlain), "logrotate plain content\n")
|
||||
|
||||
logrotateGz := "client.log.2.gz"
|
||||
writeGzFile(t, filepath.Join(dir, logrotateGz), "logrotate gz content\n")
|
||||
|
||||
names := runAddRotatedLogFiles(t, dir, 10)
|
||||
|
||||
require.Contains(t, names, timberjackRotated, "timberjack rotated file should be in bundle")
|
||||
require.Contains(t, names, logrotatePlain, "logrotate plain rotated file should be in bundle")
|
||||
require.Contains(t, names, logrotateGz, "logrotate gzipped rotated file should be in bundle")
|
||||
require.NotContains(t, names, "client.log", "active log should not be added by addRotatedLogFiles")
|
||||
require.NotContains(t, names, "other.log", "unrelated files should not be in bundle")
|
||||
}
|
||||
|
||||
// TestAddRotatedLogFiles_RespectsLogFileCount asserts that only the newest
|
||||
// logFileCount rotated files are bundled, ordered by mtime.
|
||||
func TestAddRotatedLogFiles_RespectsLogFileCount(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
oldest := filepath.Join(dir, "client.log.3")
|
||||
middle := filepath.Join(dir, "client.log.2")
|
||||
newest := filepath.Join(dir, "client.log.1")
|
||||
writeFile(t, oldest, "old\n")
|
||||
writeFile(t, middle, "mid\n")
|
||||
writeFile(t, newest, "new\n")
|
||||
|
||||
now := time.Now()
|
||||
require.NoError(t, os.Chtimes(oldest, now.Add(-2*time.Hour), now.Add(-2*time.Hour)))
|
||||
require.NoError(t, os.Chtimes(middle, now.Add(-1*time.Hour), now.Add(-1*time.Hour)))
|
||||
require.NoError(t, os.Chtimes(newest, now, now))
|
||||
|
||||
names := runAddRotatedLogFiles(t, dir, 2)
|
||||
|
||||
require.Contains(t, names, "client.log.1")
|
||||
require.Contains(t, names, "client.log.2")
|
||||
require.NotContains(t, names, "client.log.3", "oldest file should be dropped when logFileCount=2")
|
||||
}
|
||||
|
||||
// runAddRotatedLogFiles calls addRotatedLogFiles against a fresh in-memory
|
||||
// zip writer and returns the set of entry names that ended up in the archive.
|
||||
func runAddRotatedLogFiles(t *testing.T, dir string, logFileCount uint32) map[string]struct{} {
|
||||
t.Helper()
|
||||
|
||||
var buf bytes.Buffer
|
||||
g := &BundleGenerator{
|
||||
archive: zip.NewWriter(&buf),
|
||||
logFileCount: logFileCount,
|
||||
}
|
||||
g.addRotatedLogFiles(dir)
|
||||
require.NoError(t, g.archive.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
|
||||
names := make(map[string]struct{}, len(zr.File))
|
||||
for _, f := range zr.File {
|
||||
names[f.Name] = struct{}{}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func writeFile(t *testing.T, path, content string) {
|
||||
t.Helper()
|
||||
require.NoError(t, os.WriteFile(path, []byte(content), 0o644))
|
||||
}
|
||||
|
||||
func writeGzFile(t *testing.T, path, content string) {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
_, err := io.WriteString(gw, content)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, gw.Close())
|
||||
require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o644))
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android && !ios
|
||||
//go:build !android
|
||||
|
||||
package debug
|
||||
|
||||
|
||||
@@ -777,13 +777,24 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
// context is released rather than leaked until GC.
|
||||
func (s *DefaultServer) registerFallback() {
|
||||
originalNameservers := s.hostManager.getOriginalNameservers()
|
||||
if len(originalNameservers) == 0 {
|
||||
|
||||
serverIP := s.service.RuntimeIP()
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == serverIP {
|
||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
|
||||
continue
|
||||
}
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
|
||||
s.clearFallback()
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
@@ -797,11 +808,6 @@ func (s *DefaultServer) registerFallback() {
|
||||
return
|
||||
}
|
||||
handler.selectedRoutes = s.selectedRoutes
|
||||
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
handler.addRace(servers)
|
||||
|
||||
prev := s.fallbackHandler
|
||||
|
||||
@@ -72,6 +72,7 @@ import (
|
||||
sProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/capture"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@@ -1072,6 +1073,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||
state.FQDN = conf.GetFqdn()
|
||||
state.WgPort = e.config.WgPort
|
||||
|
||||
e.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
@@ -1150,6 +1152,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
||||
LogPath: e.config.LogPath,
|
||||
TempDir: e.config.TempDir,
|
||||
ClientMetrics: e.clientMetrics,
|
||||
DaemonVersion: version.NetbirdVersion(),
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(true)
|
||||
},
|
||||
|
||||
@@ -111,6 +111,7 @@ type LocalPeerState struct {
|
||||
PubKey string
|
||||
KernelInterface bool
|
||||
FQDN string
|
||||
WgPort int
|
||||
Routes map[string]struct{}
|
||||
}
|
||||
|
||||
@@ -1357,6 +1358,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
|
||||
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
|
||||
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
|
||||
pbFullStatus.LocalPeerState.WgPort = int32(fs.LocalPeerState.WgPort)
|
||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
|
||||
|
||||
@@ -700,6 +700,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
if m.routeSelector.IsDeselectAll() {
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
|
||||
71
client/internal/routemanager/selector_management_test.go
Normal file
71
client/internal/routemanager/selector_management_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
|
||||
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
|
||||
return route.HAMap{
|
||||
haID: []*route.Route{
|
||||
{
|
||||
ID: "r-" + route.ID(netID),
|
||||
NetID: netID,
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Enabled: true,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
|
||||
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
|
||||
})
|
||||
|
||||
t.Run("user selection is not overridden by management", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
|
||||
})
|
||||
}
|
||||
@@ -116,6 +116,14 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -26,7 +25,6 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
types "github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -67,8 +65,6 @@ func init() {
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
stateFile string
|
||||
cacheDir string
|
||||
logFilePath string
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
@@ -79,21 +75,16 @@ type Client struct {
|
||||
onHostDnsFn func([]string)
|
||||
dnsManager dns.IosDnsManager
|
||||
loginComplete bool
|
||||
connectClient *internal.ConnectClient
|
||||
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
|
||||
preloadedConfig *profilemanager.Config
|
||||
|
||||
stateMu sync.RWMutex
|
||||
connectClient *internal.ConnectClient
|
||||
config *profilemanager.Config
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, stateFile, cacheDir, logFilePath, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
||||
func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
stateFile: stateFile,
|
||||
cacheDir: cacheDir,
|
||||
logFilePath: logFilePath,
|
||||
deviceName: deviceName,
|
||||
osName: osName,
|
||||
osVersion: osVersion,
|
||||
@@ -170,13 +161,8 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, connectClient)
|
||||
// Persist the latest sync response so DebugBundle can include the network
|
||||
// map. On iOS this is backed by disk to keep it out of the constrained
|
||||
// process memory (see the syncstore package).
|
||||
connectClient.SetSyncResponsePersistence(true)
|
||||
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -188,84 +174,6 @@ func (c *Client) Stop() {
|
||||
}
|
||||
|
||||
c.ctxCancel()
|
||||
c.setState(nil, nil)
|
||||
}
|
||||
|
||||
// DebugBundle generates a debug bundle, uploads it and returns the upload key.
|
||||
// It works with or without a running engine: when the engine is up it reuses
|
||||
// the live config, sync response and client metrics; otherwise it loads the
|
||||
// config from disk (or the preloaded tvOS config).
|
||||
func (c *Client) DebugBundle(anonymize bool) (string, error) {
|
||||
cfg, cc := c.stateSnapshot()
|
||||
|
||||
// If the engine hasn't been started, load config so we can reach management.
|
||||
if cfg == nil {
|
||||
if c.preloadedConfig != nil {
|
||||
cfg = c.preloadedConfig
|
||||
} else {
|
||||
var err error
|
||||
// Use DirectUpdateOrCreateConfig to avoid atomic file operations
|
||||
// (temp file + rename) blocked by the tvOS sandbox.
|
||||
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
StateFilePath: c.stateFile,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deps := debug.GeneratorDependencies{
|
||||
InternalConfig: cfg,
|
||||
StatusRecorder: c.recorder,
|
||||
TempDir: c.cacheDir,
|
||||
StatePath: c.stateFile,
|
||||
LogPath: c.logFilePath,
|
||||
}
|
||||
|
||||
if cc != nil {
|
||||
resp, err := cc.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
log.Warnf("get latest sync response: %v", err)
|
||||
}
|
||||
deps.SyncResponse = resp
|
||||
|
||||
if e := cc.Engine(); e != nil {
|
||||
if cm := e.GetClientMetrics(); cm != nil {
|
||||
deps.ClientMetrics = cm
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
deps,
|
||||
debug.BundleConfig{
|
||||
Anonymize: anonymize,
|
||||
IncludeSystemInfo: true,
|
||||
},
|
||||
)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("debug bundle uploaded with key %s", key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SetTraceLogLevel configure the logger to trace level
|
||||
@@ -446,12 +354,11 @@ func (c *Client) ClearLoginComplete() {
|
||||
}
|
||||
|
||||
func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
_, connectClient := c.stateSnapshot()
|
||||
if connectClient == nil {
|
||||
if c.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
engine := c.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -563,12 +470,11 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
||||
}
|
||||
|
||||
func (c *Client) SelectRoute(id string) error {
|
||||
_, connectClient := c.stateSnapshot()
|
||||
if connectClient == nil {
|
||||
if c.connectClient == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
engine := c.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -594,11 +500,10 @@ func (c *Client) SelectRoute(id string) error {
|
||||
}
|
||||
|
||||
func (c *Client) DeselectRoute(id string) error {
|
||||
_, connectClient := c.stateSnapshot()
|
||||
if connectClient == nil {
|
||||
if c.connectClient == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
engine := connectClient.Engine()
|
||||
engine := c.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -622,22 +527,6 @@ func (c *Client) DeselectRoute(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setState stores the running engine state so DebugBundle can reuse the live
|
||||
// config and ConnectClient. It is cleared on Stop.
|
||||
func (c *Client) setState(cfg *profilemanager.Config, cc *internal.ConnectClient) {
|
||||
c.stateMu.Lock()
|
||||
defer c.stateMu.Unlock()
|
||||
c.config = cfg
|
||||
c.connectClient = cc
|
||||
}
|
||||
|
||||
// stateSnapshot returns the current config and ConnectClient under the lock.
|
||||
func (c *Client) stateSnapshot() (*profilemanager.Config, *internal.ConnectClient) {
|
||||
c.stateMu.RLock()
|
||||
defer c.stateMu.RUnlock()
|
||||
return c.config, c.connectClient
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
ds := d.String()
|
||||
dotIndex := strings.Index(ds, ".")
|
||||
|
||||
@@ -1614,6 +1614,7 @@ type LocalPeerState struct {
|
||||
RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"`
|
||||
Networks []string `protobuf:"bytes,7,rep,name=networks,proto3" json:"networks,omitempty"`
|
||||
Ipv6 string `protobuf:"bytes,8,opt,name=ipv6,proto3" json:"ipv6,omitempty"`
|
||||
WgPort int32 `protobuf:"varint,9,opt,name=wgPort,proto3" json:"wgPort,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1704,6 +1705,13 @@ func (x *LocalPeerState) GetIpv6() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LocalPeerState) GetWgPort() int32 {
|
||||
if x != nil {
|
||||
return x.WgPort
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// SignalState contains the latest state of a signal connection
|
||||
type SignalState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
@@ -2709,6 +2717,7 @@ type DebugBundleRequest struct {
|
||||
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
|
||||
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
|
||||
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
|
||||
CliVersion string `protobuf:"bytes,6,opt,name=cliVersion,proto3" json:"cliVersion,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -2771,6 +2780,13 @@ func (x *DebugBundleRequest) GetLogFileCount() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *DebugBundleRequest) GetCliVersion() string {
|
||||
if x != nil {
|
||||
return x.CliVersion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type DebugBundleResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
|
||||
@@ -6389,7 +6405,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"sshHostKey\x18\x13 \x01(\fR\n" +
|
||||
"sshHostKey\x12\x12\n" +
|
||||
"\x04ipv6\x18\x14 \x01(\tR\x04ipv6\"\x84\x02\n" +
|
||||
"\x04ipv6\x18\x14 \x01(\tR\x04ipv6\"\x9c\x02\n" +
|
||||
"\x0eLocalPeerState\x12\x0e\n" +
|
||||
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
|
||||
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" +
|
||||
@@ -6398,7 +6414,8 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x10rosenpassEnabled\x18\x05 \x01(\bR\x10rosenpassEnabled\x120\n" +
|
||||
"\x13rosenpassPermissive\x18\x06 \x01(\bR\x13rosenpassPermissive\x12\x1a\n" +
|
||||
"\bnetworks\x18\a \x03(\tR\bnetworks\x12\x12\n" +
|
||||
"\x04ipv6\x18\b \x01(\tR\x04ipv6\"S\n" +
|
||||
"\x04ipv6\x18\b \x01(\tR\x04ipv6\x12\x16\n" +
|
||||
"\x06wgPort\x18\t \x01(\x05R\x06wgPort\"S\n" +
|
||||
"\vSignalState\x12\x10\n" +
|
||||
"\x03URL\x18\x01 \x01(\tR\x03URL\x12\x1c\n" +
|
||||
"\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" +
|
||||
@@ -6475,14 +6492,17 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
|
||||
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
|
||||
"\x17ForwardingRulesResponse\x12,\n" +
|
||||
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
|
||||
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xb4\x01\n" +
|
||||
"\x12DebugBundleRequest\x12\x1c\n" +
|
||||
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
|
||||
"\n" +
|
||||
"systemInfo\x18\x03 \x01(\bR\n" +
|
||||
"systemInfo\x12\x1c\n" +
|
||||
"\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" +
|
||||
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\"}\n" +
|
||||
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\x12\x1e\n" +
|
||||
"\n" +
|
||||
"cliVersion\x18\x06 \x01(\tR\n" +
|
||||
"cliVersion\"}\n" +
|
||||
"\x13DebugBundleResponse\x12\x12\n" +
|
||||
"\x04path\x18\x01 \x01(\tR\x04path\x12 \n" +
|
||||
"\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" +
|
||||
|
||||
@@ -349,6 +349,7 @@ message LocalPeerState {
|
||||
bool rosenpassPermissive = 6;
|
||||
repeated string networks = 7;
|
||||
string ipv6 = 8;
|
||||
int32 wgPort = 9;
|
||||
}
|
||||
|
||||
// SignalState contains the latest state of a signal connection
|
||||
@@ -471,6 +472,7 @@ message DebugBundleRequest {
|
||||
bool systemInfo = 3;
|
||||
string uploadURL = 4;
|
||||
uint32 logFileCount = 5;
|
||||
string cliVersion = 6;
|
||||
}
|
||||
|
||||
message DebugBundleResponse {
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if ! which realpath > /dev/null 2>&1
|
||||
then
|
||||
echo realpath is not installed
|
||||
echo run: brew install coreutils
|
||||
exit 1
|
||||
if ! which realpath >/dev/null 2>&1; then
|
||||
echo realpath is not installed
|
||||
echo run: brew install coreutils
|
||||
exit 1
|
||||
fi
|
||||
|
||||
old_pwd=$(pwd)
|
||||
script_path=$(dirname $(realpath "$0"))
|
||||
script_path=$(dirname "$(realpath "$0")")
|
||||
cd "$script_path"
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.6.1
|
||||
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
|
||||
cd "$old_pwd"
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// DebugBundle creates a debug bundle and returns the location.
|
||||
@@ -67,6 +68,8 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
CapturePath: capturePath,
|
||||
RefreshStatus: refreshStatus,
|
||||
ClientMetrics: clientMetrics,
|
||||
DaemonVersion: version.NetbirdVersion(),
|
||||
CliVersion: req.CliVersion,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
Anonymize: req.GetAnonymize(),
|
||||
|
||||
@@ -143,6 +143,7 @@ type OutputOverview struct {
|
||||
IPv6 string `json:"netbirdIpv6,omitempty" yaml:"netbirdIpv6,omitempty"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
WgPort int `json:"wireguardPort" yaml:"wireguardPort"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||
@@ -187,6 +188,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
IPv6: pbFullStatus.GetLocalPeerState().GetIpv6(),
|
||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||
WgPort: int(pbFullStatus.GetLocalPeerState().GetWgPort()),
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||
@@ -547,6 +549,21 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
||||
}
|
||||
|
||||
daemonVersion := "N/A"
|
||||
if o.DaemonVersion != "" {
|
||||
daemonVersion = o.DaemonVersion
|
||||
}
|
||||
|
||||
cliVersion := version.NetbirdVersion()
|
||||
if o.CliVersion != "" {
|
||||
cliVersion = o.CliVersion
|
||||
}
|
||||
|
||||
wgPortString := "N/A"
|
||||
if o.WgPort > 0 {
|
||||
wgPortString = fmt.Sprintf("%d", o.WgPort)
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf(
|
||||
"OS: %s\n"+
|
||||
"Daemon version: %s\n"+
|
||||
@@ -560,6 +577,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"NetBird IP: %s\n"+
|
||||
"%s"+
|
||||
"Interface type: %s\n"+
|
||||
"Wireguard port: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Lazy connection: %s\n"+
|
||||
"SSH Server: %s\n"+
|
||||
@@ -567,8 +585,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"%s"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
o.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
daemonVersion,
|
||||
cliVersion,
|
||||
o.ProfileName,
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
@@ -578,6 +596,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
interfaceIP,
|
||||
ipv6Line,
|
||||
interfaceTypeString,
|
||||
wgPortString,
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
|
||||
@@ -94,6 +94,7 @@ var resp = &proto.StatusResponse{
|
||||
Ipv6: "fd00::100",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
WgPort: 51820,
|
||||
Fqdn: "some-localhost.awesome-domain.com",
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
@@ -210,6 +211,7 @@ var overview = OutputOverview{
|
||||
IPv6: "fd00::100",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
WgPort: 51820,
|
||||
FQDN: "some-localhost.awesome-domain.com",
|
||||
NSServerGroups: []NsServerGroupStateOutput{
|
||||
{
|
||||
@@ -369,6 +371,7 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"netbirdIpv6": "fd00::100",
|
||||
"publicKey": "Some-Pub-Key",
|
||||
"usesKernelInterface": true,
|
||||
"wireguardPort": 51820,
|
||||
"fqdn": "some-localhost.awesome-domain.com",
|
||||
"quantumResistance": false,
|
||||
"quantumResistancePermissive": false,
|
||||
@@ -487,6 +490,7 @@ netbirdIp: 192.168.178.100/16
|
||||
netbirdIpv6: fd00::100
|
||||
publicKey: Some-Pub-Key
|
||||
usesKernelInterface: true
|
||||
wireguardPort: 51820
|
||||
fqdn: some-localhost.awesome-domain.com
|
||||
quantumResistance: false
|
||||
quantumResistancePermissive: false
|
||||
@@ -579,12 +583,13 @@ FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
NetBird IPv6: fd00::100
|
||||
Interface type: Kernel
|
||||
Wireguard port: %d
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion, overview.WgPort)
|
||||
|
||||
assert.Equal(t, expectedDetail, detail)
|
||||
}
|
||||
@@ -604,6 +609,7 @@ FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
NetBird IPv6: fd00::100
|
||||
Interface type: Kernel
|
||||
Wireguard port: 51820
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
|
||||
@@ -502,7 +502,7 @@ func (s *serviceClient) getConnectionForm() *widget.Form {
|
||||
{Text: "Pre-shared Key", Widget: s.iPreSharedKey},
|
||||
{Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive},
|
||||
{Text: "Interface Name", Widget: s.iInterfaceName},
|
||||
{Text: "Interface Port", Widget: s.iInterfacePort},
|
||||
{Text: "Interface Port", Widget: s.iInterfacePort, HintText: "If set to 0, a random free port will be used"},
|
||||
{Text: "MTU", Widget: s.iMTU},
|
||||
{Text: "Log File", Widget: s.iLogFile},
|
||||
},
|
||||
@@ -558,8 +558,8 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
|
||||
if err != nil {
|
||||
return 0, 0, errors.New("invalid interface port")
|
||||
}
|
||||
if port < 1 || port > 65535 {
|
||||
return 0, 0, errors.New("invalid interface port: out of range 1-65535")
|
||||
if port < 0 || port > 65535 {
|
||||
return 0, 0, errors.New("invalid interface port: out of range 0-65535")
|
||||
}
|
||||
|
||||
var mtu int64
|
||||
@@ -1438,7 +1438,7 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
|
||||
}
|
||||
|
||||
config.WgIface = cfg.InterfaceName
|
||||
if cfg.WireguardPort != 0 {
|
||||
if cfg.WireguardPort >= 0 && cfg.WireguardPort <= 65535 {
|
||||
config.WgPort = int(cfg.WireguardPort)
|
||||
} else {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// Initial state for the debug collection
|
||||
@@ -462,6 +463,7 @@ func (s *serviceClient) createDebugBundleFromCollection(
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: params.anonymize,
|
||||
SystemInfo: params.systemInfo,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
}
|
||||
|
||||
if params.upload {
|
||||
@@ -593,6 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymize,
|
||||
SystemInfo: systemInfo,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
}
|
||||
|
||||
if uploadURL != "" {
|
||||
|
||||
@@ -99,6 +99,9 @@ func addFields(entry *logrus.Entry) {
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxUserAgent, ok := entry.Context.Value(context.UserAgentKey).(string); ok {
|
||||
entry.Data[context.UserAgentKey] = ctxUserAgent
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -24,13 +24,13 @@ require (
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.80.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
)
|
||||
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.7.0
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
|
||||
github.com/DeRuina/timberjack v1.4.2
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.38.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
||||
|
||||
4
go.sum
4
go.sum
@@ -29,6 +29,8 @@ github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+
|
||||
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/DeRuina/timberjack v1.4.2 h1:4bKlzhKdsR+2oNkgef9mqb4n11ICow8VK88RfzJPzN8=
|
||||
github.com/DeRuina/timberjack v1.4.2/go.mod h1:RLoeQrwrCGIEF8gO5nV5b/gMD0QIy7bzQhBUgpp1EqE=
|
||||
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
|
||||
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
|
||||
github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0=
|
||||
@@ -940,8 +942,6 @@ gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8
|
||||
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
|
||||
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
|
||||
@@ -19,6 +19,46 @@ readonly MSG_SEPARATOR="=========================================="
|
||||
# Utility Functions
|
||||
############################################
|
||||
|
||||
# check_docker_sock_perms verifies that the current user can read and write
# the Docker socket.
#
# The socket path is taken from DOCKER_HOST (with a leading "unix://" removed),
# defaulting to /var/run/docker.sock. If the path is not a socket the check is
# skipped and the function returns 0. If the socket exists but is not both
# readable and writable, a diagnostic is printed to stderr and the script
# exits with status 1.
check_docker_sock_perms() {
  local sock="${DOCKER_HOST:-unix:///var/run/docker.sock}"
  sock="${sock#unix://}"

  # Not a unix socket (e.g. a tcp:// DOCKER_HOST or a missing file): nothing to check.
  if [[ ! -S "$sock" ]]; then
    return 0
  fi

  if [[ ! -r "$sock" ]] || [[ ! -w "$sock" ]]; then
    local group
    # stat uses different flags on macOS than on other platforms.
    if [[ "${OSTYPE}" == "darwin"* ]]; then
      group="$(stat -f '%Sg' "$sock")"
    else
      group="$(stat -c '%G' "$sock")"
    fi

    # Write to the existing stderr descriptor (>&2) rather than re-opening
    # /dev/stderr: re-opening truncates the target when stderr is redirected
    # to a regular file, and can fail if the descriptor cannot be re-opened.
    {
      echo "Cannot access Docker socket: $sock"
      echo ""
      echo "Socket permissions:"
      ls -l "$sock"
      echo ""

      if [[ "$group" == "docker" ]]; then
        echo "Your user may need to be added to the '$group' group:"
        echo "  sudo usermod -aG $group \"$USER\""
        echo "Then log out and back in, or run this for the current shell:"
        echo "  newgrp $group"
        echo "Note: newgrp is temporary; usermod is the permanent group change."
      else
        echo "The Docker socket is owned by the '$group' group, which is not the standard 'docker' group."
        echo "For safety, this script will not suggest adding your user to '$group'."
        echo "Instead, either run this script with appropriate privileges (for example, via sudo) or follow Docker's post-install steps to configure access via the 'docker' group:"
        echo "  https://docs.docker.com/engine/install/linux-postinstall/"
      fi
    } >&2

    exit 1
  fi

  return 0
}
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null
|
||||
then
|
||||
@@ -311,11 +351,12 @@ initialize_default_values() {
|
||||
NETBIRD_STUN_PORT=3478
|
||||
|
||||
# Docker images
|
||||
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
||||
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
|
||||
# Combined server replaces separate signal, relay, and management containers
|
||||
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:latest"
|
||||
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:latest"
|
||||
|
||||
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
|
||||
NETBIRD_PROXY_IMAGE=${NETBIRD_PROXY_IMAGE:-"netbirdio/reverse-proxy:latest"}
|
||||
TRAEFIK_IMAGE=${TRAEFIK_IMAGE:-"traefik:v3.6"}
|
||||
CROWDSEC_IMAGE=${CROWDSEC_IMAGE:-"crowdsecurity/crowdsec:v1.7.7"}
|
||||
# Reverse proxy configuration
|
||||
REVERSE_PROXY_TYPE="0"
|
||||
TRAEFIK_EXTERNAL_NETWORK=""
|
||||
@@ -580,12 +621,15 @@ start_services_and_show_instructions() {
|
||||
}
|
||||
|
||||
init_environment() {
|
||||
# Check if docker compose is installed using check_docker_compose function
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
check_docker_sock_perms
|
||||
|
||||
initialize_default_values
|
||||
configure_domain
|
||||
configure_reverse_proxy
|
||||
|
||||
check_jq
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
|
||||
check_existing_installation
|
||||
generate_configuration_files
|
||||
@@ -656,7 +700,7 @@ render_docker_compose_traefik_builtin() {
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
||||
crowdsec_service="
|
||||
crowdsec:
|
||||
image: crowdsecurity/crowdsec:v1.7.7
|
||||
image: $CROWDSEC_IMAGE
|
||||
container_name: netbird-crowdsec
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
@@ -687,7 +731,7 @@ render_docker_compose_traefik_builtin() {
|
||||
services:
|
||||
# Traefik reverse proxy (automatic TLS via Let's Encrypt)
|
||||
traefik:
|
||||
image: traefik:v3.6
|
||||
image: $TRAEFIK_IMAGE
|
||||
container_name: netbird-traefik
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
@@ -771,7 +815,7 @@ $traefik_dynamic_volume
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
# gRPC router (needs h2c backend for HTTP/2 cleartext)
|
||||
- traefik.http.routers.netbird-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/signalexchange.SignalExchange/\`) || PathPrefix(\`/management.ManagementService/\`))
|
||||
- traefik.http.routers.netbird-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/signalexchange.SignalExchange/\`) || PathPrefix(\`/management.ManagementService/\`) || PathPrefix(\`/management.ProxyService/\`))
|
||||
- traefik.http.routers.netbird-grpc.entrypoints=websecure
|
||||
- traefik.http.routers.netbird-grpc.tls=true
|
||||
- traefik.http.routers.netbird-grpc.tls.certresolver=letsencrypt
|
||||
|
||||
@@ -122,7 +122,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
s.errCh = make(chan error, 4)
|
||||
|
||||
if s.autoResolveDomains {
|
||||
s.resolveDomains(srvCtx)
|
||||
s.ResolveDomains(srvCtx)
|
||||
}
|
||||
|
||||
s.PeersManager()
|
||||
@@ -398,10 +398,10 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
|
||||
}()
|
||||
}
|
||||
|
||||
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
|
||||
// ResolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
|
||||
// Fresh installs use the default self-hosted domain, while existing installs reuse the
|
||||
// persisted account domain to keep addressing stable across config changes.
|
||||
func (s *BaseServer) resolveDomains(ctx context.Context) {
|
||||
func (s *BaseServer) ResolveDomains(ctx context.Context) {
|
||||
st := s.Store()
|
||||
|
||||
setDefault := func(logMsg string, args ...any) {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
|
||||
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||
Inject[store.Store](srv, mockStore)
|
||||
|
||||
srv.resolveDomains(context.Background())
|
||||
srv.ResolveDomains(context.Background())
|
||||
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||
@@ -40,7 +40,7 @@ func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
|
||||
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||
Inject[store.Store](srv, mockStore)
|
||||
|
||||
srv.resolveDomains(context.Background())
|
||||
srv.ResolveDomains(context.Background())
|
||||
|
||||
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
|
||||
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
|
||||
@@ -56,7 +56,7 @@ func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
|
||||
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||
Inject[store.Store](srv, mockStore)
|
||||
|
||||
srv.resolveDomains(context.Background())
|
||||
srv.ResolveDomains(context.Background())
|
||||
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||
|
||||
@@ -666,8 +666,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
errChan <- err
|
||||
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
|
||||
return
|
||||
}
|
||||
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -978,6 +980,7 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
Mode: m.Mode,
|
||||
ListenPort: m.ListenPort,
|
||||
AccessRestrictions: m.AccessRestrictions,
|
||||
Private: m.Private,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
88
management/internals/shared/grpc/proxy_clone_test.go
Normal file
88
management/internals/shared/grpc/proxy_clone_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// authTokenField is the only per-proxy field that shallowCloneMapping must NOT
|
||||
// copy from the source, since callers assign it individually after cloning.
|
||||
const authTokenField = "AuthToken"
|
||||
|
||||
// TestShallowCloneMapping_ClonesAllFields populates every exported field of
|
||||
// ProxyMapping with a non-zero value and verifies the clone carries each one
|
||||
// (except AuthToken). It uses reflection so adding a new field to ProxyMapping
|
||||
// without updating shallowCloneMapping fails this test.
|
||||
func TestShallowCloneMapping_ClonesAllFields(t *testing.T) {
|
||||
src := &proto.ProxyMapping{}
|
||||
populated := populateExportedFields(t, reflect.ValueOf(src).Elem())
|
||||
require.NotEmpty(t, populated, "ProxyMapping should expose fields to populate")
|
||||
|
||||
clone := shallowCloneMapping(src)
|
||||
require.NotNil(t, clone, "clone must not be nil")
|
||||
|
||||
srcVal := reflect.ValueOf(src).Elem()
|
||||
cloneVal := reflect.ValueOf(clone).Elem()
|
||||
|
||||
for _, name := range populated {
|
||||
srcField := srcVal.FieldByName(name).Interface()
|
||||
cloneField := cloneVal.FieldByName(name).Interface()
|
||||
|
||||
if name == authTokenField {
|
||||
assert.Zero(t, cloneField, "AuthToken must not be cloned; it is set per proxy after cloning")
|
||||
continue
|
||||
}
|
||||
|
||||
assert.Equal(t, srcField, cloneField, "field %s must be carried over by shallowCloneMapping", name)
|
||||
}
|
||||
}
|
||||
|
||||
// populateExportedFields sets a non-zero value on every settable exported field
|
||||
// of the struct and returns their names.
|
||||
func populateExportedFields(t *testing.T, v reflect.Value) []string {
|
||||
t.Helper()
|
||||
|
||||
var names []string
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
structField := typ.Field(i)
|
||||
|
||||
if structField.PkgPath != "" || !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
setNonZero(t, field, structField.Name)
|
||||
names = append(names, structField.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// setNonZero assigns a deterministic non-zero value to field, chosen by the
// field's kind:
//
//   - string: "non-zero-" followed by name
//   - bool: true
//   - signed/unsigned integers and floats: 7
//   - pointer: a pointer to a freshly allocated zero value of the element type
//   - slice: a slice of length 1 holding one zero element
//   - map: an empty, non-nil map
//
// name is the field name; it is used for the string value and for the failure
// message. Any other kind aborts the test with t.Fatalf so that the helper is
// extended deliberately when a new field kind shows up.
func setNonZero(t *testing.T, field reflect.Value, name string) {
	t.Helper()

	switch field.Kind() {
	case reflect.String:
		field.SetString("non-zero-" + name)
	case reflect.Bool:
		field.SetBool(true)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		field.SetInt(7)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		field.SetUint(7)
	case reflect.Float32, reflect.Float64:
		// Floating-point fields are as easy to populate as integers; handle
		// them here rather than failing the test on the first float field.
		field.SetFloat(7)
	case reflect.Ptr:
		field.Set(reflect.New(field.Type().Elem()))
	case reflect.Slice:
		field.Set(reflect.MakeSlice(field.Type(), 1, 1))
	case reflect.Map:
		// An empty map is still non-nil, so it differs from the zero value.
		field.Set(reflect.MakeMapWithSize(field.Type(), 0))
	default:
		t.Fatalf("unhandled field kind %s for field %s; extend setNonZero", field.Kind(), name)
	}
}
|
||||
@@ -12,6 +12,7 @@ const (
|
||||
RoleKey = nbcontext.RoleKey
|
||||
UserIDKey = nbcontext.UserIDKey
|
||||
PeerIDKey = nbcontext.PeerIDKey
|
||||
UserAgentKey = nbcontext.UserAgentKey
|
||||
)
|
||||
|
||||
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.
|
||||
|
||||
@@ -1216,6 +1216,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
Preload("NetworkResources").
|
||||
Preload("Onboarding").
|
||||
Preload("Services.Targets").
|
||||
Preload("Domains").
|
||||
Take(&account, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
@@ -1302,7 +1303,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 12)
|
||||
errChan := make(chan error, 16)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1403,6 +1404,17 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
account.Services = services
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
domains, err := s.ListCustomDomains(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.Domains = domains
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +23,63 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// TestGetAccount_LoadsCustomDomains verifies GetAccount populates account.Domains.
|
||||
// SynthesizePrivateServiceZones depends on this relation to anchor a custom-domain
|
||||
// private service's DNS zone; without the preload the relation is empty and the
|
||||
// service is silently skipped, so a custom domain never resolves on clients.
|
||||
func TestGetAccount_LoadsCustomDomains(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
assertGetAccountLoadsCustomDomains(t, store)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetAccount_LoadsCustomDomains(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
assertGetAccountLoadsCustomDomains(t, store)
|
||||
}
|
||||
|
||||
// assertGetAccountLoadsCustomDomains exercises both the gorm and pgx GetAccount
|
||||
// paths: it persists two custom domains and asserts the relation comes back
|
||||
// populated, which SynthesizePrivateServiceZones relies on.
|
||||
func assertGetAccountLoadsCustomDomains(t *testing.T, store Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := "acct-custom-domains"
|
||||
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
|
||||
|
||||
_, err := store.CreateCustomDomain(ctx, accountID, "example.com", "eu.proxy.netbird.io", true)
|
||||
require.NoError(t, err, "creating the first custom domain must succeed")
|
||||
_, err = store.CreateCustomDomain(ctx, accountID, "apps.acme.io", "us.proxy.netbird.io", false)
|
||||
require.NoError(t, err, "creating the second custom domain must succeed")
|
||||
|
||||
account, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, account.Domains, 2, "GetAccount must preload the account's custom domains")
|
||||
|
||||
byDomain := map[string]string{}
|
||||
for _, d := range account.Domains {
|
||||
require.NotNil(t, d)
|
||||
byDomain[d.Domain] = d.TargetCluster
|
||||
}
|
||||
assert.Equal(t, "eu.proxy.netbird.io", byDomain["example.com"], "custom domain must carry its target cluster")
|
||||
assert.Equal(t, "us.proxy.netbird.io", byDomain["apps.acme.io"], "custom domain must carry its target cluster")
|
||||
}
|
||||
|
||||
// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
|
||||
// all fields and nested objects from the database, including deeply nested structures.
|
||||
func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {
|
||||
|
||||
@@ -21,6 +21,8 @@ const (
|
||||
httpRequestCounterPrefix = "management.http.request.counter"
|
||||
httpResponseCounterPrefix = "management.http.response.counter"
|
||||
httpRequestDurationPrefix = "management.http.request.duration.ms"
|
||||
|
||||
RequestIDHeader = "X-Request-Id"
|
||||
)
|
||||
|
||||
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
|
||||
@@ -172,6 +174,10 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
reqID := xid.New().String()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.UserAgentKey, r.UserAgent())
|
||||
|
||||
rw.Header().Set(RequestIDHeader, reqID)
|
||||
|
||||
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)
|
||||
|
||||
|
||||
@@ -273,7 +273,7 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
zonesByCluster := map[string]*nbdns.CustomZone{}
|
||||
zonesByApex := map[string]*nbdns.CustomZone{}
|
||||
|
||||
for _, svc := range a.Services {
|
||||
if svc == nil || !svc.Enabled || !svc.Private {
|
||||
@@ -290,19 +290,24 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByCluster[svc.ProxyCluster]
|
||||
serviceDomainZone := a.privateServiceDomainZone(svc)
|
||||
if serviceDomainZone == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByApex[serviceDomainZone]
|
||||
if !exists {
|
||||
// NonAuthoritative makes this a match-only zone: queries for
|
||||
// names without an explicit record fall through to the
|
||||
// upstream resolver instead of returning NXDOMAIN. Without
|
||||
// it, adding a single private service would black-hole every
|
||||
// other name under the cluster apex.
|
||||
// other name under the zone apex.
|
||||
zone = &nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(svc.ProxyCluster),
|
||||
Domain: dns.Fqdn(serviceDomainZone),
|
||||
Records: []nbdns.SimpleRecord{},
|
||||
NonAuthoritative: true,
|
||||
}
|
||||
zonesByCluster[svc.ProxyCluster] = zone
|
||||
zonesByApex[serviceDomainZone] = zone
|
||||
}
|
||||
|
||||
emitted := 0
|
||||
@@ -340,8 +345,8 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
|
||||
for _, zone := range zonesByCluster {
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByApex))
|
||||
for _, zone := range zonesByApex {
|
||||
if len(zone.Records) == 0 {
|
||||
continue
|
||||
}
|
||||
@@ -357,6 +362,33 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
return out
|
||||
}
|
||||
|
||||
// privateServiceDomainZone returns the DNS zone name for the given private service domain by
|
||||
// looking at the proxy cluster domain then the custom domains.
|
||||
func (a *Account) privateServiceDomainZone(svc *service.Service) string {
|
||||
if domainFromSuffix(svc.Domain, svc.ProxyCluster) {
|
||||
return svc.ProxyCluster
|
||||
}
|
||||
|
||||
// Longest matching custom domain wins
|
||||
zoneName := ""
|
||||
for _, d := range a.Domains {
|
||||
if d == nil || d.TargetCluster != svc.ProxyCluster {
|
||||
continue
|
||||
}
|
||||
if domainFromSuffix(svc.Domain, d.Domain) && len(d.Domain) > len(zoneName) {
|
||||
zoneName = d.Domain
|
||||
}
|
||||
}
|
||||
return zoneName
|
||||
}
|
||||
|
||||
// domainFromSuffix reports whether domain is suffix itself or a name below it
// on a label boundary (that is, domain ends with "." + suffix). An empty
// suffix never matches. The comparison is a plain, case-sensitive string
// comparison.
func domainFromSuffix(domain, suffix string) bool {
	switch {
	case suffix == "":
		return false
	case domain == suffix:
		return true
	default:
		return strings.HasSuffix(domain, "."+suffix)
	}
}
|
||||
|
||||
// peerInDistributionGroups reports whether any of the peer's groups
|
||||
// matches the service's bearer-auth distribution_groups.
|
||||
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
@@ -234,6 +235,113 @@ func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testi
|
||||
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomain_ZoneApexIsRegisteredDomain(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// A custom-domain service: Domain is the custom FQDN, ProxyCluster
|
||||
// is the cluster serving it, and account.Domains holds the registered
|
||||
// custom domain. The synth zone apex must be the registered domain,
|
||||
// not the cluster, or the client's match-only zone never intercepts
|
||||
// the query.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "custom-domain service must still produce one zone")
|
||||
zone := zones[0]
|
||||
assert.Equal(t, "example.com.", zone.Domain, "zone apex must be the registered custom domain, not the cluster or the service FQDN")
|
||||
assert.True(t, zone.NonAuthoritative, "synth zone must remain match-only")
|
||||
require.Len(t, zone.Records, 1, "custom-domain service yields one A record")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "app.example.com.", rec.Name, "record name is the custom service FQDN")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomAndFreeDomain_SeparateZones(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "custom",
|
||||
Domain: "app.example.com",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 2, "a free-domain and a custom-domain service must not collapse into one zone")
|
||||
|
||||
free, ok := findCustomZone(zones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "free-domain service keeps the shared cluster-apex zone")
|
||||
require.Len(t, free.Records, 1, "cluster zone carries only the free-domain record")
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", free.Records[0].Name, "cluster zone record is the free-domain FQDN")
|
||||
|
||||
custom, ok := findCustomZone(zones, "example.com")
|
||||
require.True(t, ok, "custom-domain service gets its own zone at the registered custom domain apex")
|
||||
require.Len(t, custom.Records, 1, "custom zone carries only the custom-domain record")
|
||||
assert.Equal(t, "app.example.com.", custom.Records[0].Name, "custom zone record is the custom-domain FQDN")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCustomDomain_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
account.Services[0].Domain = "a.example.com"
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "bapp",
|
||||
Domain: "b.example.com",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "two services under the same registered custom domain must share one zone")
|
||||
assert.Equal(t, "example.com.", zones[0].Domain, "shared zone apex is the registered custom domain")
|
||||
require.Len(t, zones[0].Records, 2, "both services surface as records in the shared custom-domain zone")
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"a.example.com.", "b.example.com."}, names, "both custom-domain service FQDNs must surface")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomainNotRegistered_NoZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// Service domain is outside the cluster and no account.Domains entry
|
||||
// covers it: there is no apex that would intercept the query, so the
|
||||
// service must be skipped rather than emit an unmatchable record.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "a custom-domain service with no registered domain apex must not produce a zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomainClusterMismatch_NoZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// The registered custom domain matches the service FQDN by suffix but
|
||||
// targets a different cluster than the service's ProxyCluster. It must
|
||||
// be ignored, leaving no apex to intercept the query — otherwise the
|
||||
// zone would point at this cluster's proxy peers under a domain owned
|
||||
// by a different cluster.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "us.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "a custom domain targeting a different cluster must not anchor the service zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
@@ -254,3 +362,72 @@ func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_MixedClusterCustomAndPublic(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
privateService := func(id, domain string) *service.Service {
|
||||
return &service.Service{
|
||||
ID: id,
|
||||
AccountID: "acct-1",
|
||||
Name: id,
|
||||
Domain: domain,
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
}
|
||||
}
|
||||
publicService := func(id, domain string) *service.Service {
|
||||
s := privateService(id, domain)
|
||||
s.Private = false
|
||||
return s
|
||||
}
|
||||
|
||||
account.Services = []*service.Service{
|
||||
// 3 private services under the cluster suffix.
|
||||
privateService("cluster-1", "cluster1.eu.proxy.netbird.io"),
|
||||
privateService("cluster-2", "cluster2.eu.proxy.netbird.io"),
|
||||
privateService("cluster-3", "cluster3.eu.proxy.netbird.io"),
|
||||
// 4 private services under the custom domain suffix.
|
||||
privateService("custom-1", "custom1.example.com"),
|
||||
privateService("custom-2", "custom2.example.com"),
|
||||
privateService("custom-3", "custom3.example.com"),
|
||||
privateService("custom-4", "custom4.example.com"),
|
||||
// 2 public services, one per suffix, must not surface.
|
||||
publicService("public-cluster", "public.eu.proxy.netbird.io"),
|
||||
publicService("public-custom", "public.example.com"),
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 2, "one zone per apex: the cluster apex and the custom domain apex")
|
||||
|
||||
cluster, ok := findCustomZone(zones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "cluster-suffix services collapse into the cluster-apex zone")
|
||||
clusterNames := recordNames(cluster)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"cluster1.eu.proxy.netbird.io.", "cluster2.eu.proxy.netbird.io.", "cluster3.eu.proxy.netbird.io."},
|
||||
clusterNames,
|
||||
"only the 3 private cluster services surface in the cluster zone (public one excluded)")
|
||||
|
||||
custom, ok := findCustomZone(zones, "example.com")
|
||||
require.True(t, ok, "custom-suffix services collapse into the custom-domain-apex zone")
|
||||
customNames := recordNames(custom)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"custom1.example.com.", "custom2.example.com.", "custom3.example.com.", "custom4.example.com."},
|
||||
customNames,
|
||||
"only the 4 private custom services surface in the custom zone (public one excluded)")
|
||||
}
|
||||
|
||||
// recordNames returns the record names of a zone for order-independent assertions.
|
||||
func recordNames(zone nbdns.CustomZone) []string {
|
||||
names := make([]string, 0, len(zone.Records))
|
||||
for _, r := range zone.Records {
|
||||
names = append(names, r.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
@@ -557,7 +557,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
|
||||
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
@@ -628,9 +627,14 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
|
||||
|
||||
rules := []*RouteFirewallRule{&rule}
|
||||
|
||||
if includeIPv6 && r.IsDynamic() {
|
||||
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
|
||||
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
if isDefaultV4 {
|
||||
ruleV6.Destination = "::/0"
|
||||
ruleV6.RouteID = r.ID + "-v6-default"
|
||||
}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1029,6 +1030,48 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
|
||||
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
|
||||
}
|
||||
|
||||
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
|
||||
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
|
||||
// (::/0 source and destination) for peers that support IPv6, mirroring the route
|
||||
// the client installs. Without it, IPv6 traffic is routed to the exit node but
|
||||
// dropped at the forward chain.
|
||||
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(20, 2)
|
||||
|
||||
routingPeerID := "peer-5"
|
||||
routingPeer := account.Peers[routingPeerID]
|
||||
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
|
||||
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
|
||||
|
||||
account.Routes["route-exit"] = &route.Route{
|
||||
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
PeerID: routingPeerID, Peer: routingPeer.Key,
|
||||
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
|
||||
AccessControlGroups: []string{},
|
||||
AccountID: "test-account",
|
||||
}
|
||||
|
||||
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
hasV4 := false
|
||||
hasV6 := false
|
||||
for _, rfr := range nm.RoutesFirewallRules {
|
||||
switch rfr.Destination {
|
||||
case "0.0.0.0/0":
|
||||
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
|
||||
hasV4 = true
|
||||
}
|
||||
case "::/0":
|
||||
if slices.Contains(rfr.SourceRanges, "::/0") {
|
||||
hasV6 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
|
||||
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// 15. MULTIPLE ROUTERS PER NETWORK
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -249,6 +249,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
Private: private,
|
||||
MaxDialTimeout: maxDialTimeout,
|
||||
MaxSessionIdleTimeout: maxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: envDurationOrDefault("NB_PROXY_MAPPING_BATCH_WATCHDOG", 0),
|
||||
GeoDataDir: geoDataDir,
|
||||
CrowdSecAPIURL: crowdsecAPIURL,
|
||||
CrowdSecAPIKey: crowdsecAPIKey,
|
||||
|
||||
@@ -28,6 +28,10 @@ import (
|
||||
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
const clientStopTimeout = 30 * time.Second
|
||||
|
||||
const createProxyPeerTimeout = 30 * time.Second
|
||||
|
||||
// backendKey identifies a backend by its host:port from the target URL.
|
||||
type backendKey string
|
||||
|
||||
@@ -162,6 +166,7 @@ type NetBird struct {
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
lifecycleMu sync.Map
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
// readyHandler runs after the embedded client for an account reports
|
||||
@@ -177,6 +182,10 @@ type NetBird struct {
|
||||
// (i.e. when a new client was actually created, not when an existing one
|
||||
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
|
||||
OnAddPeer func(d time.Duration, err error)
|
||||
|
||||
// startClient runs the post-create client startup. Nil uses runClientStartup;
|
||||
// tests override it to avoid a real embed client.Start.
|
||||
startClient func(accountID types.AccountID, client *embed.Client)
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
@@ -200,31 +209,20 @@ type skipTLSVerifyContextKey struct{}
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
// Use a background context, not the caller's: the management
|
||||
// connection notification must land even if the request /
|
||||
// stream that triggered this registration is cancelled.
|
||||
// Mirrors the async runClientStartup path.
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -234,10 +232,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
n.OnAddPeer(time.Since(createStart), err)
|
||||
}
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
@@ -246,17 +244,64 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
"service_key": key,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip. runClientStartup uses its
|
||||
// own background context so the caller's request-scoped ctx can't
|
||||
// cancel the inbound bring-up.
|
||||
go n.runClientStartup(accountID, entry.client)
|
||||
transferred = true
|
||||
go func() {
|
||||
defer lifecycle.Unlock()
|
||||
n.startClientStartup(accountID, entry.client)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) startClientStartup(accountID types.AccountID, client *embed.Client) {
|
||||
if n.startClient != nil {
|
||||
n.startClient(accountID, client)
|
||||
return
|
||||
}
|
||||
n.runClientStartup(accountID, client)
|
||||
}
|
||||
|
||||
// registerExistingClient registers the service against an already-present
|
||||
// client for the account and returns true when it did. It notifies management
|
||||
// of the new service when the client is already started.
|
||||
func (n *NetBird) registerExistingClient(accountID types.AccountID, key ServiceKey, si serviceInfo) bool {
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
n.clientsMux.Unlock()
|
||||
return false
|
||||
}
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, si.serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// accountLifecycle returns the per-account lifecycle mutex, serialising client
|
||||
// creation against teardown so a slow client.Stop cannot race a new
|
||||
// client.Start for the same account, without blocking clientsMux.
|
||||
func (n *NetBird) accountLifecycle(accountID types.AccountID) *sync.Mutex {
|
||||
mu, _ := n.lifecycleMu.LoadOrStore(accountID, &sync.Mutex{})
|
||||
return mu.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
// and creates an embedded NetBird client. Must be called with the account's
|
||||
// lifecycle mutex held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -276,7 +321,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
"public_key": publicKey.String(),
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
createCtx, cancel := context.WithTimeout(ctx, createProxyPeerTimeout)
|
||||
defer cancel()
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(createCtx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: string(serviceID),
|
||||
AccountId: string(accountID),
|
||||
Token: authToken,
|
||||
@@ -444,6 +491,15 @@ func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Cli
|
||||
// RemovePeer unregisters a service from an account. The client is only stopped
|
||||
// when no services are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
@@ -466,17 +522,8 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
delete(entry.services, key)
|
||||
|
||||
stopClient := len(entry.services) == 0
|
||||
var client *embed.Client
|
||||
var transport, insecureTransport *http.Transport
|
||||
var inbound any
|
||||
var stopHandler func(types.AccountID, any)
|
||||
if stopClient {
|
||||
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
|
||||
client = entry.client
|
||||
transport = entry.transport
|
||||
insecureTransport = entry.insecureTransport
|
||||
inbound = entry.inbound
|
||||
stopHandler = n.stopHandler
|
||||
delete(n.clients, accountID)
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -490,19 +537,40 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
|
||||
|
||||
if stopClient {
|
||||
if inbound != nil && stopHandler != nil {
|
||||
stopHandler(accountID, inbound)
|
||||
}
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
transferred = true
|
||||
go n.stopClientLocked(accountID, lifecycle, entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopClientLocked releases a client's resources off the caller's goroutine so a
|
||||
// slow client.Stop cannot wedge the mapping receive loop (which calls RemovePeer
|
||||
// synchronously). It unlocks lifecycle when done so a new client.Start for the
|
||||
// same account waits for this teardown.
|
||||
func (n *NetBird) stopClientLocked(accountID types.AccountID, lifecycle *sync.Mutex, entry *clientEntry) {
|
||||
defer lifecycle.Unlock()
|
||||
|
||||
if entry.inbound != nil && n.stopHandler != nil {
|
||||
n.stopHandler(accountID, entry.inbound)
|
||||
}
|
||||
if entry.transport != nil {
|
||||
entry.transport.CloseIdleConnections()
|
||||
}
|
||||
if entry.insecureTransport != nil {
|
||||
entry.insecureTransport.CloseIdleConnections()
|
||||
}
|
||||
if entry.client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
|
||||
defer cancel()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -22,6 +23,18 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// signalMgmtClient closes entered the first time CreateProxyPeer is called, so
|
||||
// tests can detect AddPeer reaching client creation.
|
||||
type signalMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (m *signalMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
m.once.Do(func() { close(m.entered) })
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
type mockStatusNotifier struct {
|
||||
mu sync.Mutex
|
||||
statuses []statusCall
|
||||
@@ -52,11 +65,15 @@ func (m *mockStatusNotifier) calls() []statusCall {
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
return NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, nil, &mockMgmtClient{})
|
||||
// Skip the real embed client.Start, which would hang against the unreachable
|
||||
// mgmt URL and (now that the lifecycle lock spans startup) serialise removes.
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
return nb
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
@@ -288,6 +305,7 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first service — creates a new client entry.
|
||||
@@ -372,6 +390,117 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
// TestNetBird_RemovePeer_TeardownIsAsync proves the fix for the receive-loop
|
||||
// stall: RemovePeer must return promptly even when the client teardown blocks,
|
||||
// because teardown runs off the caller's goroutine. The receive loop calls
|
||||
// RemovePeer synchronously, so a blocking teardown inline would wedge it.
|
||||
func TestNetBird_RemovePeer_TeardownIsAsync(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
|
||||
accountID := types.AccountID("acct-async-teardown")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
teardownEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
close(teardownEntered)
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- nb.RemovePeer(context.Background(), accountID, key) }()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("RemovePeer did not return while teardown was blocked — teardown is not async")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-teardownEntered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("teardown never ran")
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
}
|
||||
|
||||
// TestNetBird_AddPeer_WaitsForTeardown proves the lifecycle lock serialises a
|
||||
// new client bringup behind an in-flight teardown for the same account, so a
|
||||
// slow client.Stop can never race a new client.Start for that account.
|
||||
//
|
||||
// It targets the handoff race specifically: AddPeer is launched immediately
|
||||
// after RemovePeer returns, WITHOUT waiting for the teardown goroutine to start.
|
||||
// This only passes if RemovePeer acquires the lifecycle lock synchronously
|
||||
// (before returning) and hands it to the teardown goroutine — if the goroutine
|
||||
// acquired the lock itself, AddPeer could win the lock in this window and start
|
||||
// a replacement client while the old teardown is still pending.
|
||||
func TestNetBird_AddPeer_WaitsForTeardown(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
|
||||
accountID := types.AccountID("acct-serialize")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
addEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
// Block teardown until released. If AddPeer ever reaches createClientEntry
|
||||
// (signalled via the mgmt client below) while we hold the lock, the lock
|
||||
// failed to serialise and the test fails before we release.
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// createClientEntry calls CreateProxyPeer; closing addEntered there tells us
|
||||
// AddPeer got past the lifecycle lock and into client creation.
|
||||
nb.mgmtClient = &signalMgmtClient{entered: addEntered}
|
||||
|
||||
require.NoError(t, nb.RemovePeer(context.Background(), accountID, key))
|
||||
|
||||
// Launch AddPeer with NO synchronisation against the teardown goroutine.
|
||||
addReturned := make(chan struct{})
|
||||
go func() {
|
||||
_ = nb.AddPeer(context.Background(), accountID, DomainServiceKey("svc2.example"), "key-2", types.ServiceID("svc-2"))
|
||||
close(addReturned)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-addEntered:
|
||||
t.Fatal("AddPeer entered client creation while teardown held the lifecycle lock — handoff race not closed")
|
||||
case <-addReturned:
|
||||
t.Fatal("AddPeer completed while teardown held the lifecycle lock — not serialised")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
select {
|
||||
case <-addReturned:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("AddPeer never completed after teardown released the lifecycle lock")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
|
||||
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
|
||||
// a fresh context.Background() rather than inheriting the AddPeer
|
||||
|
||||
@@ -114,6 +114,10 @@ type Config struct {
|
||||
MaxDialTimeout time.Duration
|
||||
// MaxSessionIdleTimeout caps the per-service session idle timeout.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// being applied before the receive loop reconnects to resync. Zero falls
|
||||
// back to the internal default.
|
||||
MappingBatchWatchdog time.Duration
|
||||
|
||||
// GeoDataDir is the directory containing GeoLite2 MMDB files.
|
||||
GeoDataDir string
|
||||
@@ -164,6 +168,7 @@ func New(ctx context.Context, cfg Config) *Server {
|
||||
Private: cfg.Private,
|
||||
MaxDialTimeout: cfg.MaxDialTimeout,
|
||||
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: cfg.MappingBatchWatchdog,
|
||||
GeoDataDir: cfg.GeoDataDir,
|
||||
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
|
||||
CrowdSecAPIKey: cfg.CrowdSecAPIKey,
|
||||
|
||||
282
proxy/mapping_stall_test.go
Normal file
282
proxy/mapping_stall_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// blockingMgmtClient implements roundtrip's managementClient interface.
|
||||
// CreateProxyPeer parks until release is closed, signalling entry on entered.
|
||||
// This reproduces the confirmed real-world stall: createClientEntry calls
|
||||
// CreateProxyPeer synchronously while holding clientsMux, and the proxy's
|
||||
// receive loop calls that path synchronously inside processMappings.
|
||||
type blockingMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (b *blockingMgmtClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
b.once.Do(func() { close(b.entered) })
|
||||
// Park until the caller's context is cancelled. In production this ctx is
|
||||
// the gRPC mapping-stream context with no per-call timeout, so a slow or
|
||||
// unresponsive CreateProxyPeer parks the receive loop here indefinitely.
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// gatedMappingStream is a mock GetMappingUpdate client stream that hands out a
|
||||
// pre-seeded list of messages, then records how many times Recv advanced. It
|
||||
// lets the test observe whether the single-threaded receive loop ever gets
|
||||
// past the first (blocking) batch to pull the second message.
|
||||
type gatedMappingStream struct {
|
||||
grpc.ClientStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
idx int32
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
i := int(atomic.LoadInt32(&g.idx))
|
||||
if i >= len(g.messages) {
|
||||
// Block instead of returning EOF so the loop doesn't exit; we only
|
||||
// care whether the loop ever reaches this second Recv at all.
|
||||
select {}
|
||||
}
|
||||
msg := g.messages[i]
|
||||
atomic.AddInt32(&g.idx, 1)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) deliveredCount() int32 { return atomic.LoadInt32(&g.idx) }
|
||||
|
||||
func (g *gatedMappingStream) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil
|
||||
func (g *gatedMappingStream) Trailer() metadata.MD { return nil }
|
||||
func (g *gatedMappingStream) CloseSend() error { return nil }
|
||||
func (g *gatedMappingStream) Context() context.Context { return context.Background() }
|
||||
func (g *gatedMappingStream) SendMsg(any) error { return nil }
|
||||
func (g *gatedMappingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
// noopNotifier satisfies roundtrip's statusNotifier interface.
|
||||
type noopNotifier struct{}
|
||||
|
||||
func (noopNotifier) NotifyStatus(context.Context, types.AccountID, types.ServiceID, bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// noopProxyClient is a proto.ProxyServiceClient that no-ops the one method the
|
||||
// teardown unwind reaches (SendStatusUpdate, via notifyError when the parked
|
||||
// AddPeer is cancelled). The embedded nil interface satisfies the rest at
|
||||
// compile time; none of those methods are called by this test.
|
||||
type noopProxyClient struct {
|
||||
proto.ProxyServiceClient
|
||||
}
|
||||
|
||||
func (noopProxyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
|
||||
return &proto.SendStatusUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenApplyBlocks proves the deadlock: the proxy's
|
||||
// mapping receive loop processes batches strictly serially, so when applying
|
||||
// one batch blocks (here: createClientEntry parked on a synchronous
|
||||
// CreateProxyPeer call, exactly as observed in production), the loop never
|
||||
// advances to Recv the next batch. Management can keep sending updates onto
|
||||
// the stream with no error and no channel overflow, yet the proxy applies
|
||||
// nothing further — it is stuck.
|
||||
func TestMappingStream_StallsWhenApplyBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
mgmt := &blockingMgmtClient{
|
||||
entered: make(chan struct{}),
|
||||
}
|
||||
|
||||
nb := roundtrip.NewNetBird(
|
||||
context.Background(),
|
||||
"proxy-test",
|
||||
"proxy.example.com",
|
||||
roundtrip.ClientConfig{},
|
||||
logger,
|
||||
noopNotifier{},
|
||||
mgmt,
|
||||
)
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
netbird: nb,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
// First batch: a CREATED mapping for a brand-new account. addMapping ->
|
||||
// netbird.AddPeer -> createClientEntry -> CreateProxyPeer, which blocks.
|
||||
// Empty Path keeps setupHTTPMapping a no-op (it returns early), so the
|
||||
// ONLY blocking point is the synchronous CreateProxyPeer in AddPeer —
|
||||
// no routers/auth need wiring. The second batch exists only to detect
|
||||
// whether the loop ever advances past the blocked first batch.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-1",
|
||||
AccountId: "acct-1",
|
||||
AuthToken: "token-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-2",
|
||||
AccountId: "acct-2",
|
||||
AuthToken: "token-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Unblock the parked apply on teardown via ctx (CreateProxyPeer returns
|
||||
// ctx.Err()), so the wedged loop goroutine unwinds before embed.New —
|
||||
// avoiding any dependency on collaborators this test deliberately leaves
|
||||
// nil. The deadlock is fully proven before this fires.
|
||||
t.Cleanup(cancel)
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(ctx, stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
// The loop must reach the blocking apply for the first batch.
|
||||
select {
|
||||
case <-mgmt.entered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached CreateProxyPeer for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: while the first batch is parked in CreateProxyPeer, the
|
||||
// single-threaded loop cannot advance. The second batch is never pulled,
|
||||
// even though it is already available on the stream. Give it ample time.
|
||||
// deliveredCount is atomic; syncDone is intentionally not read here because
|
||||
// the loop goroutine owns it (reading it from the test would race).
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first is blocked in apply — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged in apply")
|
||||
default:
|
||||
// Still wedged, as expected.
|
||||
}
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenRemoveBlocks proves the deadlock for the REMOVE
|
||||
// path observed in production: a mapping remove tears down the account's last
|
||||
// embedded client via netbird.RemovePeer -> client.Stop -> Engine.Stop, whose
|
||||
// jobExecutorWG.Wait() is unbounded. Because the receive loop is single-
|
||||
// threaded, a blocked remove wedges the loop: no further mapping updates of any
|
||||
// kind (create/modify/remove) are applied, while management keeps sending them
|
||||
// successfully (no send error, no channel-full). Matches the reported symptom:
|
||||
// the last log line is a remove that stops a client, then silence.
|
||||
func TestMappingStream_StallsWhenRemoveBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
enteredRemove := make(chan struct{})
|
||||
blockRemove := make(chan struct{})
|
||||
var once sync.Once
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
// Stand in for netbird.RemovePeer -> client.Stop hanging on
|
||||
// Engine.Stop's unbounded jobExecutorWG.Wait(). Only the first remove
|
||||
// blocks; later removes return immediately so the recovery assertion
|
||||
// can observe the loop advancing.
|
||||
removePeer: func(ctx context.Context, _ types.AccountID, _ roundtrip.ServiceKey) error {
|
||||
first := false
|
||||
once.Do(func() {
|
||||
first = true
|
||||
close(enteredRemove)
|
||||
})
|
||||
if !first {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-blockRemove:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Batch 1 removes a service (blocks in teardown). Batch 2 is a later update
|
||||
// that must never be applied while the remove is wedged.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-1", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-2", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-enteredRemove:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached the blocking remove for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: the loop is parked in the blocked remove and cannot advance.
|
||||
// syncDone is owned by the loop goroutine, so it is not read here.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first remove is blocked — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged on the remove")
|
||||
default:
|
||||
}
|
||||
|
||||
// Unblock and confirm the wedge was solely the blocked remove: the loop
|
||||
// then advances and consumes the next batch.
|
||||
close(blockRemove)
|
||||
assert.Eventually(t, func() bool {
|
||||
return stream.deliveredCount() >= 2
|
||||
}, 2*time.Second, 5*time.Millisecond,
|
||||
"once the remove unblocks, the loop must advance and consume the next batch")
|
||||
}
|
||||
@@ -118,6 +118,9 @@ type Server struct {
|
||||
// The mapping worker waits on this before processing updates.
|
||||
routerReady chan struct{}
|
||||
|
||||
// removePeer defaults to netbird.RemovePeer; overridable in tests.
|
||||
removePeer func(ctx context.Context, accountID types.AccountID, key roundtrip.ServiceKey) error
|
||||
|
||||
// inbound, when non-nil, manages per-account inbound listeners. Set by
|
||||
// initPrivateInbound only when Private is true so the standalone
|
||||
// proxy keeps its zero-overhead default path.
|
||||
@@ -227,6 +230,10 @@ type Server struct {
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// in processMappings before the receive loop reconnects to resync.
|
||||
// Zero uses defaultMappingBatchWatchdog.
|
||||
MappingBatchWatchdog time.Duration
|
||||
}
|
||||
|
||||
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
|
||||
@@ -1172,24 +1179,30 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
connected := false
|
||||
onConnected := func() { connected = true }
|
||||
|
||||
var streamErr error
|
||||
if syncSupported {
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone, onConnected)
|
||||
if isSyncUnimplemented(streamErr) {
|
||||
syncSupported = false
|
||||
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
}
|
||||
} else {
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
// Stream established — reset backoff so the next failure retries quickly.
|
||||
bo.Reset()
|
||||
// Reset backoff only when a stream actually connected, so immediate
|
||||
// connect failures still back off instead of spinning.
|
||||
if connected {
|
||||
bo.Reset()
|
||||
}
|
||||
|
||||
if streamErr == nil {
|
||||
return fmt.Errorf("stream closed by server")
|
||||
@@ -1221,7 +1234,7 @@ func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
connectTime := time.Now()
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
@@ -1234,6 +1247,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return fmt.Errorf("create mapping stream: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1242,7 +1256,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
|
||||
}
|
||||
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
connectTime := time.Now()
|
||||
stream, err := client.SyncMappings(ctx)
|
||||
if err != nil {
|
||||
@@ -1263,6 +1277,7 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
|
||||
return fmt.Errorf("send sync init: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1307,7 +1322,9 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
|
||||
@@ -1391,7 +1408,9 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
}
|
||||
@@ -1456,6 +1475,44 @@ func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
return c
|
||||
}
|
||||
|
||||
const defaultMappingBatchWatchdog = 2 * time.Minute
|
||||
|
||||
// mappingBatchWatchdog returns the configured batch watchdog or the default.
|
||||
func (s *Server) mappingBatchWatchdog() time.Duration {
|
||||
if s.MappingBatchWatchdog > 0 {
|
||||
return s.MappingBatchWatchdog
|
||||
}
|
||||
return defaultMappingBatchWatchdog
|
||||
}
|
||||
|
||||
// processMappingsGuarded applies a batch under a watchdog, returning an error
|
||||
// if processing exceeds the watchdog so the caller reconnects and resyncs
|
||||
// instead of wedging silently.
|
||||
func (s *Server) processMappingsGuarded(ctx context.Context, mappings []*proto.ProxyMapping) error {
|
||||
batchCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.processMappings(batchCtx, mappings)
|
||||
}()
|
||||
|
||||
watchdog := s.mappingBatchWatchdog()
|
||||
timer := time.NewTimer(watchdog)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
s.Logger.Errorf("processing mapping batch exceeded %s, cancelling and reconnecting to resync", watchdog)
|
||||
return fmt.Errorf("mapping batch processing stalled after %s", watchdog)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
|
||||
for _, mapping := range mappings {
|
||||
@@ -1951,7 +2008,11 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
svcKey := s.serviceKeyForMapping(mapping)
|
||||
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
|
||||
removePeer := s.removePeer
|
||||
if removePeer == nil {
|
||||
removePeer = s.netbird.RemovePeer
|
||||
}
|
||||
if err := removePeer(ctx, accountID, svcKey); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": mapping.GetId(),
|
||||
|
||||
@@ -417,15 +417,30 @@ if type uname >/dev/null 2>&1; then
|
||||
# Check the availability of a compatible package manager
|
||||
if check_use_bin_variable; then
|
||||
PACKAGE_MANAGER="bin"
|
||||
elif [ -e /run/ostree-booted ]; then
|
||||
if [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v bootc)" ]; then
|
||||
echo "Detected bootc system without rpm-ostree." >&2
|
||||
echo "NetBird cannot be installed via package manager on this system." >&2
|
||||
echo "Options:" >&2
|
||||
echo " 1. Install via Distrobox (instructions in the installation docs)" >&2
|
||||
echo " 2. Rebuild your base image with rpm-ostree included" >&2
|
||||
echo " 3. Bake NetBird into your Containerfile" >&2
|
||||
exit 1
|
||||
else
|
||||
echo "Detected ostree-booted system without rpm-ostree or bootc." >&2
|
||||
echo "NetBird cannot be installed automatically on this atomic system." >&2
|
||||
echo "Please install NetBird by rebuilding your base image or use a supported package manager." >&2
|
||||
exit 1
|
||||
fi
|
||||
elif [ -x "$(command -v apt-get)" ]; then
|
||||
PACKAGE_MANAGER="apt"
|
||||
echo "The installation will be performed using apt package manager"
|
||||
elif [ -x "$(command -v dnf)" ]; then
|
||||
PACKAGE_MANAGER="dnf"
|
||||
echo "The installation will be performed using dnf package manager"
|
||||
elif [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v yum)" ]; then
|
||||
PACKAGE_MANAGER="yum"
|
||||
echo "The installation will be performed using yum package manager"
|
||||
|
||||
@@ -6,4 +6,5 @@ const (
|
||||
RoleKey = "role"
|
||||
UserIDKey = "userID"
|
||||
PeerIDKey = "peerID"
|
||||
UserAgentKey = "userAgent"
|
||||
)
|
||||
|
||||
@@ -5107,31 +5107,63 @@ components:
|
||||
responses:
|
||||
not_found:
|
||||
description: Resource not found
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
validation_failed_simple:
|
||||
description: Validation failed
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
bad_request:
|
||||
description: Bad Request
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
internal_error:
|
||||
description: Internal Server Error
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
validation_failed:
|
||||
description: Validation failed
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
forbidden:
|
||||
description: Forbidden
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
requires_authentication:
|
||||
description: Requires authentication
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
conflict:
|
||||
description: Conflict
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
headers:
|
||||
X-Request-Id:
|
||||
description: |
|
||||
Unique identifier assigned to the request by the server and set on every
|
||||
response. Useful for correlating client requests with server-side logs.
|
||||
schema:
|
||||
type: string
|
||||
example: cot7r4n3l3vh3qj4qveg
|
||||
securitySchemes:
|
||||
BearerAuth:
|
||||
type: http
|
||||
|
||||
@@ -9,12 +9,14 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
@@ -172,6 +174,19 @@ type Client struct {
|
||||
stateSubscription *PeersStateSubscription
|
||||
|
||||
mtu uint16
|
||||
|
||||
// transportFallback, when set, records datagram-too-large failures so a
|
||||
// datagram-sized transport is avoided on subsequent connects. Shared via
|
||||
// the manager.
|
||||
transportFallback *transportFallback
|
||||
// datagramFallbackTriggered guards a single fallback per connection so a
|
||||
// burst of oversized datagrams triggers one reconnect, not many.
|
||||
datagramFallbackTriggered atomic.Bool
|
||||
}
|
||||
|
||||
// SetTransportFallback wires the shared datagram-transport fallback tracker.
|
||||
func (c *Client) SetTransportFallback(tf *transportFallback) {
|
||||
c.transportFallback = tf
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
@@ -361,12 +376,13 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
dialers := c.getDialers()
|
||||
mode := transportModeFromEnv()
|
||||
dialers := c.getDialers(mode)
|
||||
|
||||
var conn net.Conn
|
||||
if c.serverIP.IsValid() {
|
||||
var err error
|
||||
conn, err = c.dialRaceDirect(ctx, dialers)
|
||||
conn, err = c.dialRaceDirect(ctx, mode, dialers)
|
||||
if err != nil {
|
||||
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
|
||||
conn = nil
|
||||
@@ -375,6 +391,9 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
|
||||
if conn == nil {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
var err error
|
||||
conn, err = rd.Dial(ctx)
|
||||
if err != nil {
|
||||
@@ -382,6 +401,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
}
|
||||
c.relayConn = conn
|
||||
c.datagramFallbackTriggered.Store(false)
|
||||
|
||||
instanceURL, err := c.handShake(ctx)
|
||||
if err != nil {
|
||||
@@ -396,7 +416,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
|
||||
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
@@ -406,6 +426,9 @@ func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
@@ -631,13 +654,53 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
|
||||
}
|
||||
|
||||
// the write always returns 0 length because the underlying connection does not support size feedback.
|
||||
_, err = c.relayConn.Write(msg)
|
||||
conn := c.relayConn
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
if errors.Is(err, netErr.ErrDatagramTooLarge) {
|
||||
c.onDatagramTooLarge(conn, err)
|
||||
} else {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
|
||||
// When a non-datagram transport is available, it records a fallback for this
|
||||
// server and closes the connection so the reconnect avoids datagram-sized
|
||||
// transports. A single fallback is triggered per connection regardless of how
|
||||
// many oversized datagrams arrive. cause carries the datagram size and budget.
|
||||
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
|
||||
// Handle one oversized datagram per connection; a burst triggers a single
|
||||
// fallback (and a single log line), not many.
|
||||
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// If the selected mode offers no non-datagram transport (e.g. pinned to a
|
||||
// datagram-sized transport), reconnecting would just re-fail, so leave the
|
||||
// connection up rather than loop.
|
||||
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
|
||||
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
|
||||
return
|
||||
}
|
||||
|
||||
// Without the shared tracker a reconnect would just select the same
|
||||
// transport again and re-fail, so leave the connection up rather than loop.
|
||||
if c.transportFallback == nil {
|
||||
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
|
||||
return
|
||||
}
|
||||
|
||||
window := c.transportFallback.recordFailure(c.connectionURL)
|
||||
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
c.log.Debugf("close relay connection for transport fallback: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
|
||||
18
shared/relay/client/dialer/capability.go
Normal file
18
shared/relay/client/dialer/capability.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package dialer
|
||||
|
||||
// DatagramSized is a marker interface for dialers whose connections send each
// write as a single datagram (e.g. QUIC), which means a write can be rejected
// for exceeding the path's datagram budget. Dialers that do not implement it
// (e.g. WebSocket over TCP) have no per-write size limit, which makes them
// valid fallbacks when a datagram-sized transport rejects a write as too
// large. Each dialer advertises the capability itself, so adding a transport
// only requires declaring whether it is datagram-sized.
type DatagramSized interface {
	DatagramSized()
}
|
||||
|
||||
// IsDatagramSized reports whether d produces datagram-sized connections.
|
||||
func IsDatagramSized(d DialeFn) bool {
|
||||
_, ok := d.(DatagramSized)
|
||||
return ok
|
||||
}
|
||||
@@ -4,4 +4,9 @@ import "errors"
|
||||
|
||||
var (
|
||||
ErrClosedByServer = errors.New("closed by server")
|
||||
|
||||
// ErrDatagramTooLarge is returned when a transport message exceeds the
|
||||
// QUIC datagram size the path to the relay can carry. The relay client
|
||||
// treats it as a signal to fall back to a non-datagram transport.
|
||||
ErrDatagramTooLarge = errors.New("datagram frame too large")
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
@@ -52,11 +51,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
err := c.session.SendDatagram(b)
|
||||
if err != nil {
|
||||
err = c.remoteCloseErrHandling(err)
|
||||
log.Errorf("failed to write to QUIC stream: %v", err)
|
||||
return 0, err
|
||||
if err := c.session.SendDatagram(b); err != nil {
|
||||
return 0, c.writeErrHandling(err, len(b))
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
@@ -95,3 +91,15 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
|
||||
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
|
||||
// datagram size and path budget) so the relay client can fall back to a
|
||||
// non-datagram transport.
|
||||
func (c *Conn) writeErrHandling(err error, size int) error {
|
||||
var tooLarge *quic.DatagramTooLargeError
|
||||
if errors.As(err, &tooLarge) {
|
||||
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
|
||||
}
|
||||
return c.remoteCloseErrHandling(err)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
@@ -23,6 +24,12 @@ func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
|
||||
// carried in QUIC DATAGRAM frames, which must fit a single packet.
|
||||
func (d Dialer) DatagramSized() {
|
||||
// Intentional marker method; presence is the capability signal.
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
@@ -47,6 +54,7 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
Tracer: connectionTracer(quicURL),
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
@@ -74,6 +82,28 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
|
||||
// reason a relay connection closed, so the path MTU settled on and teardown
|
||||
// cause are visible in logs. Lines carry the relay address as a structured
|
||||
// field, matching the rest of the relay client logging.
|
||||
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
relayLog := log.WithField("relay", addr)
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
|
||||
if done {
|
||||
relayLog.Infof("QUIC path MTU settled at %d", mtu)
|
||||
return
|
||||
}
|
||||
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
|
||||
},
|
||||
ClosedConnection: func(err error) {
|
||||
relayLog.Debugf("QUIC connection closed: %v", err)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
var host string
|
||||
var defaultPort string
|
||||
|
||||
@@ -32,6 +32,7 @@ type RaceDial struct {
|
||||
serverName string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
sequential bool
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
@@ -53,7 +54,21 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
|
||||
return r
|
||||
}
|
||||
|
||||
// WithSequential makes Dial try the dialers in order, falling back to the next
|
||||
// only when one fails to connect, instead of racing them concurrently.
|
||||
//
|
||||
// Mutates the receiver and is not safe for concurrent reconfiguration; a
|
||||
// RaceDial is intended to be constructed per dial and discarded.
|
||||
func (r *RaceDial) WithSequential() *RaceDial {
|
||||
r.sequential = true
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
if r.sequential {
|
||||
return r.dialSequential(ctx)
|
||||
}
|
||||
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
abortCtx, abort := context.WithCancel(ctx)
|
||||
@@ -72,6 +87,30 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// dialSequential tries each dialer in order, returning the first connection and
|
||||
// falling back to the next on failure.
|
||||
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
|
||||
for _, dfn := range r.dialerFns {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
|
||||
cancel()
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
|
||||
continue
|
||||
}
|
||||
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
|
||||
return conn, nil
|
||||
}
|
||||
return nil, errors.New("failed to dial to Relay server on any protocol")
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -250,3 +250,66 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialFallback(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
var firstDialed, secondDialed bool
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
firstDialed = true
|
||||
return nil, errors.New("quic unreachable")
|
||||
},
|
||||
}
|
||||
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
secondDialed = true
|
||||
return fallbackConn, nil
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback to succeed, got %v", err)
|
||||
}
|
||||
if conn != fallbackConn {
|
||||
t.Errorf("expected fallback connection, got %v", conn)
|
||||
}
|
||||
if !firstDialed || !secondDialed {
|
||||
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialPreferredWins(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return preferredConn, nil
|
||||
},
|
||||
}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
t.Errorf("fallback dialer must not be tried when preferred succeeds")
|
||||
return nil, errors.New("should not happen")
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected preferred to succeed, got %v", err)
|
||||
}
|
||||
if conn != preferredConn {
|
||||
t.Errorf("expected preferred connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,42 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// getDialers returns the list of dialers to use for connecting to the relay server.
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
// getDialers returns the ordered dialers for connecting to the relay server. It
|
||||
// applies the datagram fallback generically: if this server recently rejected a
|
||||
// datagram-sized transport, those dialers are dropped, leaving the rest.
|
||||
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
|
||||
dialers := c.baseDialers(mode)
|
||||
|
||||
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
|
||||
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
|
||||
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
|
||||
return filtered
|
||||
}
|
||||
}
|
||||
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
return dialers
|
||||
}
|
||||
|
||||
// baseDialers returns the ordered dialers for the mode, before any datagram
|
||||
// fallback filtering. For racing modes (auto) the order is irrelevant; for
|
||||
// prefer modes the first entry is tried before falling back to the second.
|
||||
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
|
||||
switch mode {
|
||||
case TransportModeWS:
|
||||
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
case TransportModeQUIC:
|
||||
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{quic.Dialer{}}
|
||||
}
|
||||
|
||||
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
if mode == TransportModePreferWS {
|
||||
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
|
||||
}
|
||||
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
|
||||
return nonDatagramSized(all)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
101
shared/relay/client/dialers_generic_test.go
Normal file
101
shared/relay/client/dialers_generic_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
//go:build !js
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// TestDatagramSizedCapability locks the capability the generic fallback relies
|
||||
// on: QUIC is datagram-sized, WebSocket is not.
|
||||
func TestDatagramSizedCapability(t *testing.T) {
|
||||
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
|
||||
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
|
||||
}
|
||||
|
||||
func protocols(dialers []dialer.DialeFn) []string {
|
||||
out := make([]string, len(dialers))
|
||||
for i, d := range dialers {
|
||||
out[i] = d.Protocol()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestGetDialers(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
mtu uint16
|
||||
preferWS bool
|
||||
want []string
|
||||
}{
|
||||
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"WS"}},
|
||||
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
|
||||
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"WS", "quic"}},
|
||||
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.mode)
|
||||
if tc.mode == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
|
||||
tf := newTransportFallback()
|
||||
if tc.preferWS {
|
||||
tf.recordFailure(url)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: tc.mtu,
|
||||
transportFallback: tf,
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
|
||||
// datagram records a fallback that makes the next dial pick WebSocket, the way a
|
||||
// reconnect would after the connection is closed.
|
||||
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: iface.DefaultMTU,
|
||||
transportFallback: newTransportFallback(),
|
||||
}
|
||||
|
||||
// First dial races both transports.
|
||||
assert.Equal(t, []string{"quic", "WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
|
||||
// An oversized datagram records the fallback for this server.
|
||||
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
|
||||
|
||||
// The reconnect now sticks to WebSocket.
|
||||
assert.Equal(t, []string{"WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
}
|
||||
@@ -7,7 +7,11 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
|
||||
// JS/WASM build only uses WebSocket transport
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
@@ -79,23 +79,30 @@ type Manager struct {
|
||||
|
||||
cleanupInterval time.Duration
|
||||
keepUnusedServerTime time.Duration
|
||||
|
||||
// transportFallback is shared across home and foreign relay clients so a
|
||||
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
|
||||
transportFallback *transportFallback
|
||||
}
|
||||
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
|
||||
tokenStore := &relayAuth.TokenStore{}
|
||||
tf := newTransportFallback()
|
||||
|
||||
m := &Manager{
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
transportFallback: tf,
|
||||
serverPicker: &ServerPicker{
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
MTU: mtu,
|
||||
ConnectionTimeout: defaultConnectionTimeout,
|
||||
TransportFallback: tf,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
@@ -287,6 +294,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
|
||||
relayClient.SetTransportFallback(m.transportFallback)
|
||||
err := relayClient.Connect(m.ctx)
|
||||
if err != nil {
|
||||
rt.err = err
|
||||
|
||||
@@ -29,6 +29,7 @@ type ServerPicker struct {
|
||||
PeerID string
|
||||
MTU uint16
|
||||
ConnectionTimeout time.Duration
|
||||
TransportFallback *transportFallback
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
@@ -70,6 +71,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||
log.Infof("try to connecting to relay server: %s", url)
|
||||
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
|
||||
relayClient.SetTransportFallback(sp.TransportFallback)
|
||||
err := relayClient.Connect(ctx)
|
||||
resultChan <- connResult{
|
||||
RelayClient: relayClient,
|
||||
|
||||
129
shared/relay/client/transport.go
Normal file
129
shared/relay/client/transport.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
)
|
||||
|
||||
// EnvRelayTransport is the environment variable that pins the relay transport.
//
// Accepted values:
//   - "auto" (default): race QUIC and WebSocket.
//   - "quic": QUIC only.
//   - "ws": WebSocket only.
//   - "prefer-quic" / "prefer-ws": try the preferred transport first and fall
//     back to the other only if it fails to connect; no race.
//
// The prefer modes give deterministic transport selection at the cost of a
// slower connect when the preferred transport is blackholed.
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
|
||||
|
||||
const (
	// transportFallbackBase is the first window during which a relay server
	// avoids datagram-sized transports after a datagram was rejected as too
	// large.
	transportFallbackBase = 10 * time.Minute
	// transportFallbackMax is the upper bound for the window when failures
	// keep repeating.
	transportFallbackMax = time.Hour
)
|
||||
|
||||
// TransportMode names the relay dialer selection strategy.
type TransportMode string

// Supported transport modes. The string values are the ones accepted in
// EnvRelayTransport.
const (
	// TransportModeAuto races QUIC and WebSocket.
	TransportModeAuto TransportMode = "auto"
	// TransportModeQUIC uses QUIC only.
	TransportModeQUIC TransportMode = "quic"
	// TransportModeWS uses WebSocket only.
	TransportModeWS TransportMode = "ws"
	// TransportModePreferQUIC tries QUIC first, then WebSocket.
	TransportModePreferQUIC TransportMode = "prefer-quic"
	// TransportModePreferWS tries WebSocket first, then QUIC.
	TransportModePreferWS TransportMode = "prefer-ws"
)
|
||||
|
||||
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
|
||||
// or unrecognized value.
|
||||
func transportModeFromEnv() TransportMode {
|
||||
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
|
||||
case "", TransportModeAuto:
|
||||
return TransportModeAuto
|
||||
case TransportModeQUIC:
|
||||
return TransportModeQUIC
|
||||
case TransportModeWS:
|
||||
return TransportModeWS
|
||||
case TransportModePreferQUIC:
|
||||
return TransportModePreferQUIC
|
||||
case TransportModePreferWS:
|
||||
return TransportModePreferWS
|
||||
default:
|
||||
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
|
||||
return TransportModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// sequential reports whether the mode tries dialers in order with fallback
|
||||
// instead of racing them concurrently.
|
||||
func (m TransportMode) sequential() bool {
|
||||
return m == TransportModePreferQUIC || m == TransportModePreferWS
|
||||
}
|
||||
|
||||
// transportFallback tracks relay servers that have rejected a datagram-sized
|
||||
// transport (a write too large for the path) and should temporarily avoid such
|
||||
// transports. It is shared across the relay manager so the preference survives
|
||||
// client recreation (foreign relay clients are evicted and rebuilt on
|
||||
// disconnect). Entries are keyed by server URL and expire after a window that
|
||||
// grows on repeated failures.
|
||||
type transportFallback struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*fallbackEntry
|
||||
}
|
||||
|
||||
// fallbackEntry is the fallback state kept for a single relay server.
type fallbackEntry struct {
	// until is the instant the current avoidance window ends.
	until time.Time
	// duration is the length of the most recently started window; it is the
	// value that gets doubled when a failure repeats after expiry.
	duration time.Duration
}
|
||||
|
||||
func newTransportFallback() *transportFallback {
|
||||
return &transportFallback{entries: make(map[string]*fallbackEntry)}
|
||||
}
|
||||
|
||||
// avoidDatagramSized reports whether serverURL is currently within a window
|
||||
// where datagram-sized transports should be avoided.
|
||||
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
e := f.entries[serverURL]
|
||||
return e != nil && time.Now().Before(e.until)
|
||||
}
|
||||
|
||||
// recordFailure makes serverURL avoid datagram-sized transports for a window:
|
||||
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
|
||||
// when a datagram transport fails again after a previous window expired. It
|
||||
// returns the active window duration.
|
||||
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
e := f.entries[serverURL]
|
||||
switch {
|
||||
case e == nil:
|
||||
e = &fallbackEntry{duration: transportFallbackBase}
|
||||
f.entries[serverURL] = e
|
||||
case now.Before(e.until):
|
||||
return time.Until(e.until)
|
||||
default:
|
||||
e.duration = min(e.duration*2, transportFallbackMax)
|
||||
}
|
||||
e.until = now.Add(e.duration)
|
||||
return e.duration
|
||||
}
|
||||
|
||||
// nonDatagramSized returns the dialers from in that are not datagram-sized,
|
||||
// preserving order.
|
||||
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
|
||||
out := make([]dialer.DialeFn, 0, len(in))
|
||||
for _, d := range in {
|
||||
if !dialer.IsDatagramSized(d) {
|
||||
out = append(out, d)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
140
shared/relay/client/transport_test.go
Normal file
140
shared/relay/client/transport_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
|
||||
// closeTrackingConn is a test double that remembers whether Close has been
// called. The embedded net.Conn is left nil, so only Close may be used.
type closeTrackingConn struct {
	net.Conn
	// closed is set to true by Close.
	closed bool
}

// Close marks the connection as closed and always succeeds.
func (c *closeTrackingConn) Close() error {
	c.closed = true
	return nil
}
|
||||
|
||||
func TestTransportModeFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
want TransportMode
|
||||
}{
|
||||
{"", TransportModeAuto},
|
||||
{"auto", TransportModeAuto},
|
||||
{"quic", TransportModeQUIC},
|
||||
{"QUIC", TransportModeQUIC},
|
||||
{"ws", TransportModeWS},
|
||||
{" Ws ", TransportModeWS},
|
||||
{"prefer-quic", TransportModePreferQUIC},
|
||||
{"prefer-ws", TransportModePreferWS},
|
||||
{"garbage", TransportModeAuto},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.value, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.value)
|
||||
if tc.value == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
assert.Equal(t, tc.want, transportModeFromEnv())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
|
||||
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
|
||||
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
|
||||
|
||||
// A second failure while still inside the window must not grow the window.
|
||||
d = f.recordFailure(url)
|
||||
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
|
||||
require.NotNil(t, f.entries[url])
|
||||
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
|
||||
|
||||
// Expire the window: datagram-sized transport allowed again.
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
|
||||
}
|
||||
|
||||
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
want := transportFallbackBase
|
||||
for i := range 6 {
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, want, d, "window after %d expiries", i)
|
||||
|
||||
// expire the window so the next failure is treated as a repeat
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
|
||||
want = min(want*2, transportFallbackMax)
|
||||
}
|
||||
|
||||
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeAuto(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.True(t, conn.closed, "connection closed to force reconnect")
|
||||
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
|
||||
|
||||
// A second oversized datagram on the same connection must not re-close.
|
||||
conn.closed = false
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
assert.False(t, conn.closed, "single fallback per connection")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
|
||||
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
|
||||
}
|
||||
|
||||
func TestTransportFallbackPerServer(t *testing.T) {
|
||||
f := newTransportFallback()
|
||||
f.recordFailure("rels://a.example:443")
|
||||
|
||||
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
|
||||
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
|
||||
}
|
||||
54
util/log.go
54
util/log.go
@@ -1,15 +1,16 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/DeRuina/timberjack"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
)
|
||||
@@ -37,8 +38,7 @@ func InitLog(logLevel string, logs ...string) error {
|
||||
func InitLogger(logger *log.Logger, logLevel string, logs ...string) error {
|
||||
level, err := log.ParseLevel(logLevel)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed parsing log-level %s: %s", logLevel, err)
|
||||
return err
|
||||
return fmt.Errorf("failed parsing log-level %s: %w", logLevel, err)
|
||||
}
|
||||
var writers []io.Writer
|
||||
logFmt := os.Getenv("NB_LOG_FORMAT")
|
||||
@@ -59,7 +59,11 @@ func InitLogger(logger *log.Logger, logLevel string, logs ...string) error {
|
||||
case "":
|
||||
logger.Warnf("empty log path received: %#v", logPath)
|
||||
default:
|
||||
writers = append(writers, newRotatedOutput(logPath))
|
||||
writer, err := setupLogFile(logPath, isRotationDisabled(logger))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed setting up log file: %s, %w", logPath, err)
|
||||
}
|
||||
writers = append(writers, writer)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,17 +98,43 @@ func FindFirstLogPath(logs []string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func isRotationDisabled(logger *log.Logger) bool {
|
||||
v, _ := os.LookupEnv("NB_LOG_DISABLE_ROTATION")
|
||||
disabled, _ := strconv.ParseBool(v)
|
||||
if disabled {
|
||||
logger.Warnf("log rotation is disabled by env flag")
|
||||
return true
|
||||
}
|
||||
conflict, configPath := FindFirstLogrotateConflict()
|
||||
if conflict {
|
||||
logger.Warnf("log rotation conflict detected in: %#v, rotation is disabled", configPath)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func setupLogFile(logPath string, disableRotation bool) (io.Writer, error) {
|
||||
if disableRotation {
|
||||
file, err := os.OpenFile(logPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
return newRotatedOutput(logPath), nil
|
||||
}
|
||||
|
||||
func newRotatedOutput(logPath string) io.Writer {
|
||||
maxLogSize := getLogMaxSize()
|
||||
lumberjackLogger := &lumberjack.Logger{
|
||||
timberjackLogger := &timberjack.Logger{
|
||||
// Log file absolute path, os agnostic
|
||||
Filename: filepath.ToSlash(logPath),
|
||||
MaxSize: maxLogSize, // MB
|
||||
MaxBackups: 10,
|
||||
MaxAge: 30, // days
|
||||
Compress: true,
|
||||
Filename: filepath.ToSlash(logPath),
|
||||
MaxSize: maxLogSize, // MB
|
||||
MaxBackups: 10,
|
||||
MaxAge: 30, // days
|
||||
Compression: "gzip",
|
||||
}
|
||||
return lumberjackLogger
|
||||
return timberjackLogger
|
||||
}
|
||||
|
||||
func setGRPCLibLogger(logger *log.Logger) {
|
||||
@@ -127,7 +157,7 @@ func getLogMaxSize() int {
|
||||
if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok {
|
||||
size, err := strconv.ParseInt(sizeVar, 10, 64)
|
||||
if err != nil {
|
||||
log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
|
||||
log.Errorf("failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
|
||||
return defaultLogSize
|
||||
}
|
||||
|
||||
|
||||
96
util/log_test.go
Normal file
96
util/log_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSetupLogFile_RotatesOnSize drives >MaxSize bytes through the writer
|
||||
// returned by setupLogFile and asserts a backup file appears.
|
||||
func TestSetupLogFile_RotatesOnSize(t *testing.T) {
|
||||
t.Setenv("NB_LOG_MAX_SIZE_MB", "1")
|
||||
|
||||
dir := t.TempDir()
|
||||
logPath := filepath.Join(dir, "netbird.log")
|
||||
|
||||
w, err := setupLogFile(logPath, false)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
if c, ok := w.(io.Closer); ok {
|
||||
_ = c.Close()
|
||||
}
|
||||
})
|
||||
|
||||
chunk := []byte(strings.Repeat("x", 64*1024) + "\n")
|
||||
for range 20 {
|
||||
_, err := w.Write(chunk)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(logPath)
|
||||
require.NoError(t, err)
|
||||
require.Less(t, info.Size(), int64(1<<20),
|
||||
"active log should be < 1 MB after rotation, got %d", info.Size())
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
entries, _ := os.ReadDir(dir)
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if name == filepath.Base(logPath) {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(name, "netbird-") && strings.HasSuffix(name, ".log.gz") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 50*time.Millisecond, "expected a rotated backup file in %s", dir)
|
||||
}
|
||||
|
||||
// TestSetupLogFile_RotationDisabled verifies that with rotation off, the file
|
||||
// grows past MaxSize and no backups are created.
|
||||
func TestSetupLogFile_RotationDisabled(t *testing.T) {
|
||||
t.Setenv("NB_LOG_MAX_SIZE_MB", "1")
|
||||
|
||||
dir := t.TempDir()
|
||||
logPath := filepath.Join(dir, "netbird.log")
|
||||
|
||||
w, err := setupLogFile(logPath, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
f, ok := w.(*os.File)
|
||||
require.True(t, ok, "expected plain *os.File when rotation is disabled, got %T", w)
|
||||
t.Cleanup(func() { _ = f.Close() })
|
||||
|
||||
chunk := []byte(strings.Repeat("y", 64*1024) + "\n")
|
||||
for range 20 {
|
||||
_, err := w.Write(chunk)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(logPath)
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, info.Size(), int64(1<<20),
|
||||
"file should exceed MaxSize when rotation is disabled, got %d", info.Size())
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1, "no backup files should exist when rotation is disabled, got %v", entries)
|
||||
}
|
||||
|
||||
// TestIsRotationDisabled_EnvFlag covers the NB_LOG_DISABLE_ROTATION env path.
|
||||
// The logrotate-conflict branch is exercised separately on linux.
|
||||
func TestIsRotationDisabled_EnvFlag(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetOutput(io.Discard)
|
||||
|
||||
t.Setenv("NB_LOG_DISABLE_ROTATION", "true")
|
||||
require.True(t, isRotationDisabled(logger))
|
||||
}
|
||||
93
util/logrotate_linux.go
Normal file
93
util/logrotate_linux.go
Normal file
@@ -0,0 +1,93 @@
|
||||
//go:build linux
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
	// defaultLogrotateConfPath is the main logrotate config file that is scanned.
	defaultLogrotateConfPath = "/etc/logrotate.conf"
	// defaultLogrotateConfDir is the directory whose files are scanned as
	// additional logrotate configs.
	defaultLogrotateConfDir = "/etc/logrotate.d"
	// netbirdString is the substring whose presence in a config marks a conflict.
	netbirdString = "netbird"
)
|
||||
|
||||
// FindLogrotateConflicts scans the standard logrotate locations for
|
||||
// indications of conflict with netbird. It returns true and the config file
|
||||
// path if a conflict was found.
|
||||
func FindFirstLogrotateConflict() (bool, string) {
|
||||
return findFirstLogrotateConflictIn(defaultLogrotateConfPath, defaultLogrotateConfDir)
|
||||
}
|
||||
|
||||
func findFirstLogrotateConflictIn(confPath, confDir string) (bool, string) {
|
||||
for _, f := range listLogrotateConfigs(confPath, confDir) {
|
||||
present, err := scanLogrotateFile(f, netbirdString)
|
||||
if err != nil {
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
log.Debugf("scan %s: %v", f, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if present {
|
||||
return present, f
|
||||
}
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// listLogrotateConfigs returns the logrotate config files to inspect:
// confPath first, followed by every non-directory entry of confDir. If
// confDir cannot be read, only confPath is returned. Existence of confPath
// is not checked here.
func listLogrotateConfigs(confPath, confDir string) []string {
	configs := []string{confPath}

	dirEntries, err := os.ReadDir(confDir)
	if err != nil {
		return configs
	}

	for i := range dirEntries {
		if entry := dirEntries[i]; !entry.IsDir() {
			configs = append(configs, filepath.Join(confDir, entry.Name()))
		}
	}
	return configs
}
|
||||
|
||||
// scanLogrotateFile reads a config and reports if a non-comment line
|
||||
// contains the given substring.
|
||||
func scanLogrotateFile(path string, substring string) (bool, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
if err := f.Close(); err != nil {
|
||||
log.Debugf("close %s: %v", path, err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(stripLogrotateComment(scanner.Text()))
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(line, substring) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// stripLogrotateComment returns line truncated at the first "#", or line
// unchanged when it contains no "#".
func stripLogrotateComment(line string) string {
	if hash := strings.IndexByte(line, '#'); hash >= 0 {
		return line[:hash]
	}
	return line
}
|
||||
95
util/logrotate_linux_test.go
Normal file
95
util/logrotate_linux_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
//go:build linux
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFindFirstLogrotateConflict(t *testing.T) {
|
||||
t.Run("conflict in confDir", func(t *testing.T) {
|
||||
confPath, confDir := newLogrotateLayout(t)
|
||||
conflictPath := filepath.Join(confDir, "netbird")
|
||||
writeLogrotateConfig(t, conflictPath, `/var/log/netbird/*.log {
|
||||
daily
|
||||
rotate 7
|
||||
}`)
|
||||
writeLogrotateConfig(t, filepath.Join(confDir, "nginx"), `/var/log/nginx/*.log { daily }`)
|
||||
|
||||
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||
require.True(t, got)
|
||||
require.Equal(t, conflictPath, path)
|
||||
})
|
||||
|
||||
t.Run("conflict in main conf file", func(t *testing.T) {
|
||||
confPath, confDir := newLogrotateLayout(t)
|
||||
writeLogrotateConfig(t, confPath, `weekly
|
||||
rotate 4
|
||||
include /etc/logrotate.d
|
||||
/var/log/netbird/client.log { rotate 5 }`)
|
||||
|
||||
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||
require.True(t, got)
|
||||
require.Equal(t, confPath, path)
|
||||
})
|
||||
|
||||
t.Run("no conflict when netbird is absent", func(t *testing.T) {
|
||||
confPath, confDir := newLogrotateLayout(t)
|
||||
writeLogrotateConfig(t, filepath.Join(confDir, "nginx"), `/var/log/nginx/*.log { daily }`)
|
||||
writeLogrotateConfig(t, filepath.Join(confDir, "syslog"), `/var/log/syslog { weekly }`)
|
||||
|
||||
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||
require.False(t, got)
|
||||
require.Empty(t, path)
|
||||
})
|
||||
|
||||
t.Run("commented-out netbird line is ignored", func(t *testing.T) {
|
||||
confPath, confDir := newLogrotateLayout(t)
|
||||
writeLogrotateConfig(t, filepath.Join(confDir, "misc"), `# /var/log/netbird/*.log { daily }
|
||||
/var/log/other.log { weekly }`)
|
||||
|
||||
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||
require.False(t, got)
|
||||
require.Empty(t, path)
|
||||
})
|
||||
|
||||
t.Run("subdirectories in confDir are ignored", func(t *testing.T) {
|
||||
confPath, confDir := newLogrotateLayout(t)
|
||||
sub := filepath.Join(confDir, "nested")
|
||||
require.NoError(t, os.MkdirAll(sub, 0o755))
|
||||
writeLogrotateConfig(t, filepath.Join(sub, "netbird"), `/var/log/netbird/*.log { daily }`)
|
||||
|
||||
got, path := findFirstLogrotateConflictIn(confPath, confDir)
|
||||
require.False(t, got)
|
||||
require.Empty(t, path)
|
||||
})
|
||||
|
||||
t.Run("missing paths return no conflict", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
got, path := findFirstLogrotateConflictIn(
|
||||
filepath.Join(dir, "does-not-exist.conf"),
|
||||
filepath.Join(dir, "does-not-exist.d"),
|
||||
)
|
||||
require.False(t, got)
|
||||
require.Empty(t, path)
|
||||
})
|
||||
}
|
||||
|
||||
// newLogrotateLayout creates a temp logrotate.conf path and logrotate.d dir,
|
||||
// returning their paths. The conf file itself is not created.
|
||||
func newLogrotateLayout(t *testing.T) (confPath, confDir string) {
|
||||
t.Helper()
|
||||
root := t.TempDir()
|
||||
confDir = filepath.Join(root, "logrotate.d")
|
||||
require.NoError(t, os.MkdirAll(confDir, 0o755))
|
||||
return filepath.Join(root, "logrotate.conf"), confDir
|
||||
}
|
||||
|
||||
func writeLogrotateConfig(t *testing.T, path, body string) {
|
||||
t.Helper()
|
||||
require.NoError(t, os.WriteFile(path, []byte(body), 0o644))
|
||||
}
|
||||
10
util/logrotate_nonlinux.go
Normal file
10
util/logrotate_nonlinux.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !linux
|
||||
|
||||
package util
|
||||
|
||||
// FindFirstLogrotateConflict is the non-linux counterpart of the logrotate
// conflict scan. No scanning is performed on these platforms: it always
// returns false and an empty path.
func FindFirstLogrotateConflict() (bool, string) {
	return false, ""
}
|
||||
Reference in New Issue
Block a user