Compare commits

..

7 Commits

53 changed files with 978 additions and 1062 deletions

View File

@@ -29,10 +29,10 @@ jobs:
persist-credentials: false
- name: Generate FreeBSD port diff
run: bash -x release_files/freebsd-port-diff.sh
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash -x release_files/freebsd-port-issue-body.sh
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff

View File

@@ -19,7 +19,6 @@ import (
"github.com/netbirdio/netbird/client/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
"github.com/netbirdio/netbird/version"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -101,7 +100,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
Anonymize: anonymizeFlag,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
CliVersion: version.NetbirdVersion(),
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
@@ -300,7 +298,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
Anonymize: anonymizeFlag,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
CliVersion: version.NetbirdVersion(),
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
@@ -435,7 +432,6 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
SyncResponse: syncResponse,
LogPath: logFilePath,
CPUProfile: nil,
DaemonVersion: version.NetbirdVersion(), // acting as daemon
},
debug.BundleConfig{
IncludeSystemInfo: true,

View File

@@ -102,7 +102,7 @@ func (p *program) Stop(srv service.Service) error {
}
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc, consoleLog bool) (service.Service, error) {
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
// rootCmd env vars are already applied by PersistentPreRunE.
SetFlagsFromEnvVars(serviceCmd)
@@ -112,14 +112,8 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel
return nil, err
}
if consoleLog {
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
} else {
if err := util.InitLog(logLevel, logFiles...); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
if err := util.InitLog(logLevel, logFiles...); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
cfg, err := newSVCConfig()
@@ -144,7 +138,7 @@ var runCmd = &cobra.Command{
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
@@ -158,7 +152,7 @@ var startCmd = &cobra.Command{
Short: "starts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
@@ -176,7 +170,7 @@ var stopCmd = &cobra.Command{
Short: "stops NetBird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
@@ -194,7 +188,7 @@ var restartCmd = &cobra.Command{
Short: "restarts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
@@ -212,7 +206,7 @@ var svcStatusCmd = &cobra.Command{
Short: "shows NetBird service status",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel, true)
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}

View File

@@ -254,8 +254,6 @@ type BundleGenerator struct {
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
daemonVersion string
cliVersion string
anonymize bool
includeSystemInfo bool
@@ -280,8 +278,6 @@ type GeneratorDependencies struct {
CapturePath string
RefreshStatus func()
ClientMetrics MetricsExporter
DaemonVersion string
CliVersion string
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -303,8 +299,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
daemonVersion: deps.DaemonVersion,
cliVersion: deps.CliVersion,
anonymize: cfg.Anonymize,
includeSystemInfo: cfg.IncludeSystemInfo,
@@ -465,11 +459,9 @@ func (g *BundleGenerator) addStatus() error {
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
Anonymize: g.anonymize,
ProfileName: profName,
DaemonVersion: g.daemonVersion,
Anonymize: g.anonymize,
ProfileName: profName,
})
overview.CliVersion = g.cliVersion
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
@@ -806,8 +798,6 @@ func (g *BundleGenerator) addSyncResponse() error {
AllowPartial: true,
}
g.maskSecrets()
jsonBytes, err := options.Marshal(g.syncResponse)
if err != nil {
return fmt.Errorf("generate json: %w", err)
@@ -820,27 +810,6 @@ func (g *BundleGenerator) addSyncResponse() error {
return nil
}
func (g *BundleGenerator) maskSecrets() {
if g.syncResponse == nil || g.syncResponse.NetbirdConfig == nil {
return
}
if g.syncResponse.NetbirdConfig.Flow != nil {
g.syncResponse.NetbirdConfig.Flow.TokenPayload = maskedValue
}
if g.syncResponse.NetbirdConfig.Relay != nil {
g.syncResponse.NetbirdConfig.Relay.TokenPayload = maskedValue
}
for i := range g.syncResponse.NetbirdConfig.Turns {
if g.syncResponse.NetbirdConfig.Turns[i] != nil {
g.syncResponse.NetbirdConfig.Turns[i].Password = maskedValue
}
}
}
func (g *BundleGenerator) addStateFile() error {
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
@@ -1070,8 +1039,7 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
return
}
// This regex will match both logs rotated by us and logrotate on linux
pattern := filepath.Join(logDir, "client*.log.*")
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
@@ -1104,12 +1072,7 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
for i := 0; i < maxFiles; i++ {
name := filepath.Base(files[i])
if strings.HasSuffix(name, ".gz") {
err = g.addSingleLogFileGz(files[i], name)
} else {
err = g.addSingleLogfile(files[i], name)
}
if err != nil {
if err := g.addSingleLogFileGz(files[i], name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}

View File

@@ -1,103 +0,0 @@
package debug
import (
"archive/zip"
"bytes"
"compress/gzip"
"io"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestAddRotatedLogFiles_PicksUpAllVariants asserts that the rotated-log
// glob picks up logs rotated by timberjack (gzipped) and by logrotate (plain
// and gzipped), and skips unrelated files.
func TestAddRotatedLogFiles_PicksUpAllVariants(t *testing.T) {
dir := t.TempDir()
writeFile(t, filepath.Join(dir, "client.log"), "active log\n")
writeFile(t, filepath.Join(dir, "other.log"), "unrelated\n")
timberjackRotated := "client-2026-05-21T10-30-45.000.log.gz"
writeGzFile(t, filepath.Join(dir, timberjackRotated), "timberjack rotated content\n")
logrotatePlain := "client.log.1"
writeFile(t, filepath.Join(dir, logrotatePlain), "logrotate plain content\n")
logrotateGz := "client.log.2.gz"
writeGzFile(t, filepath.Join(dir, logrotateGz), "logrotate gz content\n")
names := runAddRotatedLogFiles(t, dir, 10)
require.Contains(t, names, timberjackRotated, "timberjack rotated file should be in bundle")
require.Contains(t, names, logrotatePlain, "logrotate plain rotated file should be in bundle")
require.Contains(t, names, logrotateGz, "logrotate gzipped rotated file should be in bundle")
require.NotContains(t, names, "client.log", "active log should not be added by addRotatedLogFiles")
require.NotContains(t, names, "other.log", "unrelated files should not be in bundle")
}
// TestAddRotatedLogFiles_RespectsLogFileCount asserts that only the newest
// logFileCount rotated files are bundled, ordered by mtime.
func TestAddRotatedLogFiles_RespectsLogFileCount(t *testing.T) {
dir := t.TempDir()
oldest := filepath.Join(dir, "client.log.3")
middle := filepath.Join(dir, "client.log.2")
newest := filepath.Join(dir, "client.log.1")
writeFile(t, oldest, "old\n")
writeFile(t, middle, "mid\n")
writeFile(t, newest, "new\n")
now := time.Now()
require.NoError(t, os.Chtimes(oldest, now.Add(-2*time.Hour), now.Add(-2*time.Hour)))
require.NoError(t, os.Chtimes(middle, now.Add(-1*time.Hour), now.Add(-1*time.Hour)))
require.NoError(t, os.Chtimes(newest, now, now))
names := runAddRotatedLogFiles(t, dir, 2)
require.Contains(t, names, "client.log.1")
require.Contains(t, names, "client.log.2")
require.NotContains(t, names, "client.log.3", "oldest file should be dropped when logFileCount=2")
}
// runAddRotatedLogFiles calls addRotatedLogFiles against a fresh in-memory
// zip writer and returns the set of entry names that ended up in the archive.
func runAddRotatedLogFiles(t *testing.T, dir string, logFileCount uint32) map[string]struct{} {
t.Helper()
var buf bytes.Buffer
g := &BundleGenerator{
archive: zip.NewWriter(&buf),
logFileCount: logFileCount,
}
g.addRotatedLogFiles(dir)
require.NoError(t, g.archive.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
names := make(map[string]struct{}, len(zr.File))
for _, f := range zr.File {
names[f.Name] = struct{}{}
}
return names
}
func writeFile(t *testing.T, path, content string) {
t.Helper()
require.NoError(t, os.WriteFile(path, []byte(content), 0o644))
}
func writeGzFile(t *testing.T, path, content string) {
t.Helper()
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
_, err := io.WriteString(gw, content)
require.NoError(t, err)
require.NoError(t, gw.Close())
require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o644))
}

View File

@@ -72,7 +72,6 @@ import (
sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/capture"
"github.com/netbirdio/netbird/version"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -1073,7 +1072,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
state.FQDN = conf.GetFqdn()
state.WgPort = e.config.WgPort
e.statusRecorder.UpdateLocalPeerState(state)
@@ -1152,7 +1150,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
DaemonVersion: version.NetbirdVersion(),
RefreshStatus: func() {
e.RunHealthProbes(true)
},

View File

@@ -111,7 +111,6 @@ type LocalPeerState struct {
PubKey string
KernelInterface bool
FQDN string
WgPort int
Routes map[string]struct{}
}
@@ -1016,14 +1015,17 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
}
// extend the list of stun, turn servers with relay address
// extend the list of stun, turn servers with the relay server connections
relayStates := slices.Clone(d.relayStates)
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
states := d.relayMgr.RelayStates()
if len(states) == 0 {
// no relay connection tracked yet; surface configured servers as
// unavailable with the real reconnect error when known
err := relayClient.ErrRelayClientNotConnected
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
err = connErr
}
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
@@ -1033,10 +1035,14 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return relayStates
}
relayState := relay.ProbeResult{
URI: instanceAddr,
for _, rs := range states {
relayStates = append(relayStates, relay.ProbeResult{
URI: rs.URL,
Err: rs.Err,
Transport: rs.Transport,
})
}
return append(relayStates, relayState)
return relayStates
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
@@ -1358,7 +1364,6 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.WgPort = int32(fs.LocalPeerState.WgPort)
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
@@ -1397,6 +1402,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
Transport: relayState.Transport,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()

View File

@@ -32,6 +32,9 @@ type ProbeResult struct {
URI string
Err error
Addr string
// Transport is the negotiated relay transport, empty
// for stun/turn probes or when not connected.
Transport string
}
type StunTurnProbe struct {

View File

@@ -1614,7 +1614,6 @@ type LocalPeerState struct {
RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"`
Networks []string `protobuf:"bytes,7,rep,name=networks,proto3" json:"networks,omitempty"`
Ipv6 string `protobuf:"bytes,8,opt,name=ipv6,proto3" json:"ipv6,omitempty"`
WgPort int32 `protobuf:"varint,9,opt,name=wgPort,proto3" json:"wgPort,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1705,13 +1704,6 @@ func (x *LocalPeerState) GetIpv6() string {
return ""
}
func (x *LocalPeerState) GetWgPort() int32 {
if x != nil {
return x.WgPort
}
return 0
}
// SignalState contains the latest state of a signal connection
type SignalState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -1840,6 +1832,7 @@ type RelayState struct {
URI string `protobuf:"bytes,1,opt,name=URI,proto3" json:"URI,omitempty"`
Available bool `protobuf:"varint,2,opt,name=available,proto3" json:"available,omitempty"`
Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"`
Transport string `protobuf:"bytes,4,opt,name=transport,proto3" json:"transport,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1895,6 +1888,13 @@ func (x *RelayState) GetError() string {
return ""
}
func (x *RelayState) GetTransport() string {
if x != nil {
return x.Transport
}
return ""
}
type NSGroupState struct {
state protoimpl.MessageState `protogen:"open.v1"`
Servers []string `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
@@ -2717,7 +2717,6 @@ type DebugBundleRequest struct {
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
CliVersion string `protobuf:"bytes,6,opt,name=cliVersion,proto3" json:"cliVersion,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -2780,13 +2779,6 @@ func (x *DebugBundleRequest) GetLogFileCount() uint32 {
return 0
}
func (x *DebugBundleRequest) GetCliVersion() string {
if x != nil {
return x.CliVersion
}
return ""
}
type DebugBundleResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
@@ -6405,7 +6397,7 @@ const file_daemon_proto_rawDesc = "" +
"\n" +
"sshHostKey\x18\x13 \x01(\fR\n" +
"sshHostKey\x12\x12\n" +
"\x04ipv6\x18\x14 \x01(\tR\x04ipv6\"\x9c\x02\n" +
"\x04ipv6\x18\x14 \x01(\tR\x04ipv6\"\x84\x02\n" +
"\x0eLocalPeerState\x12\x0e\n" +
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" +
@@ -6414,8 +6406,7 @@ const file_daemon_proto_rawDesc = "" +
"\x10rosenpassEnabled\x18\x05 \x01(\bR\x10rosenpassEnabled\x120\n" +
"\x13rosenpassPermissive\x18\x06 \x01(\bR\x13rosenpassPermissive\x12\x1a\n" +
"\bnetworks\x18\a \x03(\tR\bnetworks\x12\x12\n" +
"\x04ipv6\x18\b \x01(\tR\x04ipv6\x12\x16\n" +
"\x06wgPort\x18\t \x01(\x05R\x06wgPort\"S\n" +
"\x04ipv6\x18\b \x01(\tR\x04ipv6\"S\n" +
"\vSignalState\x12\x10\n" +
"\x03URL\x18\x01 \x01(\tR\x03URL\x12\x1c\n" +
"\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" +
@@ -6423,12 +6414,13 @@ const file_daemon_proto_rawDesc = "" +
"\x0fManagementState\x12\x10\n" +
"\x03URL\x18\x01 \x01(\tR\x03URL\x12\x1c\n" +
"\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" +
"\x05error\x18\x03 \x01(\tR\x05error\"R\n" +
"\x05error\x18\x03 \x01(\tR\x05error\"p\n" +
"\n" +
"RelayState\x12\x10\n" +
"\x03URI\x18\x01 \x01(\tR\x03URI\x12\x1c\n" +
"\tavailable\x18\x02 \x01(\bR\tavailable\x12\x14\n" +
"\x05error\x18\x03 \x01(\tR\x05error\"r\n" +
"\x05error\x18\x03 \x01(\tR\x05error\x12\x1c\n" +
"\ttransport\x18\x04 \x01(\tR\ttransport\"r\n" +
"\fNSGroupState\x12\x18\n" +
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +
@@ -6492,17 +6484,14 @@ const file_daemon_proto_rawDesc = "" +
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
"\x17ForwardingRulesResponse\x12,\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xb4\x01\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
"\x12DebugBundleRequest\x12\x1c\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
"\n" +
"systemInfo\x18\x03 \x01(\bR\n" +
"systemInfo\x12\x1c\n" +
"\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" +
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\x12\x1e\n" +
"\n" +
"cliVersion\x18\x06 \x01(\tR\n" +
"cliVersion\"}\n" +
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\"}\n" +
"\x13DebugBundleResponse\x12\x12\n" +
"\x04path\x18\x01 \x01(\tR\x04path\x12 \n" +
"\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" +

View File

@@ -349,7 +349,6 @@ message LocalPeerState {
bool rosenpassPermissive = 6;
repeated string networks = 7;
string ipv6 = 8;
int32 wgPort = 9;
}
// SignalState contains the latest state of a signal connection
@@ -371,6 +370,7 @@ message RelayState {
string URI = 1;
bool available = 2;
string error = 3;
string transport = 4;
}
message NSGroupState {
@@ -472,7 +472,6 @@ message DebugBundleRequest {
bool systemInfo = 3;
string uploadURL = 4;
uint32 logFileCount = 5;
string cliVersion = 6;
}
message DebugBundleResponse {

View File

@@ -1,16 +1,17 @@
#!/bin/bash
set -e
if ! which realpath >/dev/null 2>&1; then
echo realpath is not installed
echo run: brew install coreutils
exit 1
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname "$(realpath "$0")")
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.6.1
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
cd "$old_pwd"

View File

@@ -14,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/version"
)
// DebugBundle creates a debug bundle and returns the location.
@@ -68,8 +67,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
CapturePath: capturePath,
RefreshStatus: refreshStatus,
ClientMetrics: clientMetrics,
DaemonVersion: version.NetbirdVersion(),
CliVersion: req.CliVersion,
},
debug.BundleConfig{
Anonymize: req.GetAnonymize(),

View File

@@ -98,6 +98,7 @@ type RelayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
Transport string `json:"transport,omitempty" yaml:"transport,omitempty"`
}
type RelayStateOutput struct {
@@ -143,7 +144,6 @@ type OutputOverview struct {
IPv6 string `json:"netbirdIpv6,omitempty" yaml:"netbirdIpv6,omitempty"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
WgPort int `json:"wireguardPort" yaml:"wireguardPort"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
@@ -188,7 +188,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
IPv6: pbFullStatus.GetLocalPeerState().GetIpv6(),
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
WgPort: int(pbFullStatus.GetLocalPeerState().GetWgPort()),
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
@@ -219,7 +218,8 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
RelayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relay.GetError(),
Error: relayErrorString(relay.GetError()),
Transport: relay.GetTransport(),
},
)
@@ -235,6 +235,12 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
}
}
// relayErrorString flattens a newline-joined aggregated relay error onto a
// single line for status output.
func relayErrorString(s string) string {
return strings.ReplaceAll(s, "\n", "; ")
}
func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
@@ -441,6 +447,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
} else if relay.Transport != "" {
available = fmt.Sprintf("%s via %s", available, relay.Transport)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
@@ -549,21 +557,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
daemonVersion := "N/A"
if o.DaemonVersion != "" {
daemonVersion = o.DaemonVersion
}
cliVersion := version.NetbirdVersion()
if o.CliVersion != "" {
cliVersion = o.CliVersion
}
wgPortString := "N/A"
if o.WgPort > 0 {
wgPortString = fmt.Sprintf("%d", o.WgPort)
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
@@ -577,7 +570,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"NetBird IP: %s\n"+
"%s"+
"Interface type: %s\n"+
"Wireguard port: %s\n"+
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
@@ -585,8 +577,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"%s"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
daemonVersion,
cliVersion,
o.DaemonVersion,
version.NetbirdVersion(),
o.ProfileName,
managementConnString,
signalConnString,
@@ -596,7 +588,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
interfaceIP,
ipv6Line,
interfaceTypeString,
wgPortString,
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,

View File

@@ -94,7 +94,6 @@ var resp = &proto.StatusResponse{
Ipv6: "fd00::100",
PubKey: "Some-Pub-Key",
KernelInterface: true,
WgPort: 51820,
Fqdn: "some-localhost.awesome-domain.com",
Networks: []string{
"10.10.0.0/24",
@@ -211,7 +210,6 @@ var overview = OutputOverview{
IPv6: "fd00::100",
PubKey: "Some-Pub-Key",
KernelInterface: true,
WgPort: 51820,
FQDN: "some-localhost.awesome-domain.com",
NSServerGroups: []NsServerGroupStateOutput{
{
@@ -371,7 +369,6 @@ func TestParsingToJSON(t *testing.T) {
"netbirdIpv6": "fd00::100",
"publicKey": "Some-Pub-Key",
"usesKernelInterface": true,
"wireguardPort": 51820,
"fqdn": "some-localhost.awesome-domain.com",
"quantumResistance": false,
"quantumResistancePermissive": false,
@@ -490,7 +487,6 @@ netbirdIp: 192.168.178.100/16
netbirdIpv6: fd00::100
publicKey: Some-Pub-Key
usesKernelInterface: true
wireguardPort: 51820
fqdn: some-localhost.awesome-domain.com
quantumResistance: false
quantumResistancePermissive: false
@@ -583,13 +579,12 @@ FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
NetBird IPv6: fd00::100
Interface type: Kernel
Wireguard port: %d
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion, overview.WgPort)
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
assert.Equal(t, expectedDetail, detail)
}
@@ -609,7 +604,6 @@ FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
NetBird IPv6: fd00::100
Interface type: Kernel
Wireguard port: 51820
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
@@ -647,3 +641,13 @@ func TestTimeAgo(t *testing.T) {
})
}
}
func TestMapRelaysTransport(t *testing.T) {
out := mapRelays([]*proto.RelayState{
{URI: "rels://relay.example:443", Available: true, Transport: "quic"},
{URI: "rels://relay2.example:443", Available: true, Transport: "ws"},
})
require.Len(t, out.Details, 2)
assert.Equal(t, "quic", out.Details[0].Transport)
assert.Equal(t, "ws", out.Details[1].Transport)
}

View File

@@ -502,7 +502,7 @@ func (s *serviceClient) getConnectionForm() *widget.Form {
{Text: "Pre-shared Key", Widget: s.iPreSharedKey},
{Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive},
{Text: "Interface Name", Widget: s.iInterfaceName},
{Text: "Interface Port", Widget: s.iInterfacePort, HintText: "If set to 0, a random free port will be used"},
{Text: "Interface Port", Widget: s.iInterfacePort},
{Text: "MTU", Widget: s.iMTU},
{Text: "Log File", Widget: s.iLogFile},
},
@@ -558,8 +558,8 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
if err != nil {
return 0, 0, errors.New("invalid interface port")
}
if port < 0 || port > 65535 {
return 0, 0, errors.New("invalid interface port: out of range 0-65535")
if port < 1 || port > 65535 {
return 0, 0, errors.New("invalid interface port: out of range 1-65535")
}
var mtu int64
@@ -1438,7 +1438,7 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
}
config.WgIface = cfg.InterfaceName
if cfg.WireguardPort >= 0 && cfg.WireguardPort <= 65535 {
if cfg.WireguardPort != 0 {
config.WgPort = int(cfg.WireguardPort)
} else {
config.WgPort = iface.DefaultWgPort

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
uptypes "github.com/netbirdio/netbird/upload-server/types"
"github.com/netbirdio/netbird/version"
)
// Initial state for the debug collection
@@ -463,7 +462,6 @@ func (s *serviceClient) createDebugBundleFromCollection(
request := &proto.DebugBundleRequest{
Anonymize: params.anonymize,
SystemInfo: params.systemInfo,
CliVersion: version.NetbirdVersion(),
}
if params.upload {
@@ -595,7 +593,6 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
request := &proto.DebugBundleRequest{
Anonymize: anonymize,
SystemInfo: systemInfo,
CliVersion: version.NetbirdVersion(),
}
if uploadURL != "" {

2
go.mod
View File

@@ -24,13 +24,13 @@ require (
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
gopkg.in/natefinch/lumberjack.v2 v2.2.1
)
require (
fyne.io/fyne/v2 v2.7.0
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
github.com/DeRuina/timberjack v1.4.2
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.38.3
github.com/aws/aws-sdk-go-v2/config v1.31.6

4
go.sum
View File

@@ -29,8 +29,6 @@ github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/DeRuina/timberjack v1.4.2 h1:4bKlzhKdsR+2oNkgef9mqb4n11ICow8VK88RfzJPzN8=
github.com/DeRuina/timberjack v1.4.2/go.mod h1:RLoeQrwrCGIEF8gO5nV5b/gMD0QIy7bzQhBUgpp1EqE=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0=
@@ -942,6 +940,8 @@ gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=

View File

@@ -978,7 +978,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
Mode: m.Mode,
ListenPort: m.ListenPort,
AccessRestrictions: m.AccessRestrictions,
Private: m.Private,
}
}

View File

@@ -1,88 +0,0 @@
package grpc
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
// authTokenField is the only per-proxy field that shallowCloneMapping must NOT
// copy from the source, since callers assign it individually after cloning.
const authTokenField = "AuthToken"
// TestShallowCloneMapping_ClonesAllFields populates every exported field of
// ProxyMapping with a non-zero value and verifies the clone carries each one
// (except AuthToken). It uses reflection so adding a new field to ProxyMapping
// without updating shallowCloneMapping fails this test.
func TestShallowCloneMapping_ClonesAllFields(t *testing.T) {
src := &proto.ProxyMapping{}
populated := populateExportedFields(t, reflect.ValueOf(src).Elem())
require.NotEmpty(t, populated, "ProxyMapping should expose fields to populate")
clone := shallowCloneMapping(src)
require.NotNil(t, clone, "clone must not be nil")
srcVal := reflect.ValueOf(src).Elem()
cloneVal := reflect.ValueOf(clone).Elem()
for _, name := range populated {
srcField := srcVal.FieldByName(name).Interface()
cloneField := cloneVal.FieldByName(name).Interface()
if name == authTokenField {
assert.Zero(t, cloneField, "AuthToken must not be cloned; it is set per proxy after cloning")
continue
}
assert.Equal(t, srcField, cloneField, "field %s must be carried over by shallowCloneMapping", name)
}
}
// populateExportedFields sets a non-zero value on every settable exported field
// of the struct and returns their names.
func populateExportedFields(t *testing.T, v reflect.Value) []string {
t.Helper()
var names []string
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
structField := typ.Field(i)
if structField.PkgPath != "" || !field.CanSet() {
continue
}
setNonZero(t, field, structField.Name)
names = append(names, structField.Name)
}
return names
}
// setNonZero assigns a deterministic non-zero value based on the field kind.
func setNonZero(t *testing.T, field reflect.Value, name string) {
t.Helper()
switch field.Kind() {
case reflect.String:
field.SetString("non-zero-" + name)
case reflect.Bool:
field.SetBool(true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.SetInt(7)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(7)
case reflect.Ptr:
field.Set(reflect.New(field.Type().Elem()))
case reflect.Slice:
field.Set(reflect.MakeSlice(field.Type(), 1, 1))
case reflect.Map:
field.Set(reflect.MakeMapWithSize(field.Type(), 0))
default:
t.Fatalf("unhandled field kind %s for field %s; extend setNonZero", field.Kind(), name)
}
}

View File

@@ -1216,7 +1216,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("NetworkResources").
Preload("Onboarding").
Preload("Services.Targets").
Preload("Domains").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
@@ -1303,7 +1302,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
}
var wg sync.WaitGroup
errChan := make(chan error, 16)
errChan := make(chan error, 12)
wg.Add(1)
go func() {
@@ -1404,17 +1403,6 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
account.Services = services
}()
wg.Add(1)
go func() {
defer wg.Done()
domains, err := s.ListCustomDomains(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Domains = domains
}()
wg.Add(1)
go func() {
defer wg.Done()

View File

@@ -4,8 +4,6 @@ import (
"context"
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
@@ -23,63 +21,6 @@ import (
"github.com/netbirdio/netbird/route"
)
// TestGetAccount_LoadsCustomDomains verifies GetAccount populates account.Domains.
// SynthesizePrivateServiceZones depends on this relation to anchor a custom-domain
// private service's DNS zone; without the preload the relation is empty and the
// service is silently skipped, so a custom domain never resolves on clients.
func TestGetAccount_LoadsCustomDomains(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
defer cleanup()
assertGetAccountLoadsCustomDomains(t, store)
}
func TestPostgresql_GetAccount_LoadsCustomDomains(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
assertGetAccountLoadsCustomDomains(t, store)
}
// assertGetAccountLoadsCustomDomains exercises both the gorm and pgx GetAccount
// paths: it persists two custom domains and asserts the relation comes back
// populated, which SynthesizePrivateServiceZones relies on.
func assertGetAccountLoadsCustomDomains(t *testing.T, store Store) {
t.Helper()
ctx := context.Background()
accountID := "acct-custom-domains"
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
_, err := store.CreateCustomDomain(ctx, accountID, "example.com", "eu.proxy.netbird.io", true)
require.NoError(t, err, "creating the first custom domain must succeed")
_, err = store.CreateCustomDomain(ctx, accountID, "apps.acme.io", "us.proxy.netbird.io", false)
require.NoError(t, err, "creating the second custom domain must succeed")
account, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
require.Len(t, account.Domains, 2, "GetAccount must preload the account's custom domains")
byDomain := map[string]string{}
for _, d := range account.Domains {
require.NotNil(t, d)
byDomain[d.Domain] = d.TargetCluster
}
assert.Equal(t, "eu.proxy.netbird.io", byDomain["example.com"], "custom domain must carry its target cluster")
assert.Equal(t, "us.proxy.netbird.io", byDomain["apps.acme.io"], "custom domain must carry its target cluster")
}
// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
// all fields and nested objects from the database, including deeply nested structures.
func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {

View File

@@ -273,7 +273,7 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
}
peerGroups := a.GetPeerGroups(peerID)
zonesByApex := map[string]*nbdns.CustomZone{}
zonesByCluster := map[string]*nbdns.CustomZone{}
for _, svc := range a.Services {
if svc == nil || !svc.Enabled || !svc.Private {
@@ -290,24 +290,19 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
continue
}
serviceDomainZone := a.privateServiceDomainZone(svc)
if serviceDomainZone == "" {
continue
}
zone, exists := zonesByApex[serviceDomainZone]
zone, exists := zonesByCluster[svc.ProxyCluster]
if !exists {
// NonAuthoritative makes this a match-only zone: queries for
// names without an explicit record fall through to the
// upstream resolver instead of returning NXDOMAIN. Without
// it, adding a single private service would black-hole every
// other name under the zone apex.
// other name under the cluster apex.
zone = &nbdns.CustomZone{
Domain: dns.Fqdn(serviceDomainZone),
Domain: dns.Fqdn(svc.ProxyCluster),
Records: []nbdns.SimpleRecord{},
NonAuthoritative: true,
}
zonesByApex[serviceDomainZone] = zone
zonesByCluster[svc.ProxyCluster] = zone
}
emitted := 0
@@ -345,8 +340,8 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
}
}
out := make([]nbdns.CustomZone, 0, len(zonesByApex))
for _, zone := range zonesByApex {
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
for _, zone := range zonesByCluster {
if len(zone.Records) == 0 {
continue
}
@@ -362,33 +357,6 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
return out
}
// privateServiceDomainZone returns the DNS zone name for the given private service domain by
// looking at the proxy cluster domain then the custom domains.
func (a *Account) privateServiceDomainZone(svc *service.Service) string {
if domainFromSuffix(svc.Domain, svc.ProxyCluster) {
return svc.ProxyCluster
}
// Longest matching custom domain wins
zoneName := ""
for _, d := range a.Domains {
if d == nil || d.TargetCluster != svc.ProxyCluster {
continue
}
if domainFromSuffix(svc.Domain, d.Domain) && len(d.Domain) > len(zoneName) {
zoneName = d.Domain
}
}
return zoneName
}
func domainFromSuffix(domain, suffix string) bool {
if suffix == "" {
return false
}
return domain == suffix || strings.HasSuffix(domain, "."+suffix)
}
// peerInDistributionGroups reports whether any of the peer's groups
// matches the service's bearer-auth distribution_groups.
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {

View File

@@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -235,113 +234,6 @@ func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testi
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
}
func TestSynthesizePrivateServiceZones_CustomDomain_ZoneApexIsRegisteredDomain(t *testing.T) {
account := privateZoneTestAccount(t)
// A custom-domain service: Domain is the custom FQDN, ProxyCluster
// is the cluster serving it, and account.Domains holds the registered
// custom domain. The synth zone apex must be the registered domain,
// not the cluster, or the client's match-only zone never intercepts
// the query.
account.Services[0].Domain = "app.example.com"
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "custom-domain service must still produce one zone")
zone := zones[0]
assert.Equal(t, "example.com.", zone.Domain, "zone apex must be the registered custom domain, not the cluster or the service FQDN")
assert.True(t, zone.NonAuthoritative, "synth zone must remain match-only")
require.Len(t, zone.Records, 1, "custom-domain service yields one A record")
rec := zone.Records[0]
assert.Equal(t, "app.example.com.", rec.Name, "record name is the custom service FQDN")
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
}
func TestSynthesizePrivateServiceZones_CustomAndFreeDomain_SeparateZones(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
account.Services = append(account.Services, &service.Service{
ID: "svc-2",
AccountID: "acct-1",
Name: "custom",
Domain: "app.example.com",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
})
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 2, "a free-domain and a custom-domain service must not collapse into one zone")
free, ok := findCustomZone(zones, "eu.proxy.netbird.io")
require.True(t, ok, "free-domain service keeps the shared cluster-apex zone")
require.Len(t, free.Records, 1, "cluster zone carries only the free-domain record")
assert.Equal(t, "myapp.eu.proxy.netbird.io.", free.Records[0].Name, "cluster zone record is the free-domain FQDN")
custom, ok := findCustomZone(zones, "example.com")
require.True(t, ok, "custom-domain service gets its own zone at the registered custom domain apex")
require.Len(t, custom.Records, 1, "custom zone carries only the custom-domain record")
assert.Equal(t, "app.example.com.", custom.Records[0].Name, "custom zone record is the custom-domain FQDN")
}
func TestSynthesizePrivateServiceZones_TwoServicesSameCustomDomain_OneZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
account.Services[0].Domain = "a.example.com"
account.Services = append(account.Services, &service.Service{
ID: "svc-2",
AccountID: "acct-1",
Name: "bapp",
Domain: "b.example.com",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
})
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "two services under the same registered custom domain must share one zone")
assert.Equal(t, "example.com.", zones[0].Domain, "shared zone apex is the registered custom domain")
require.Len(t, zones[0].Records, 2, "both services surface as records in the shared custom-domain zone")
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
assert.ElementsMatch(t, []string{"a.example.com.", "b.example.com."}, names, "both custom-domain service FQDNs must surface")
}
func TestSynthesizePrivateServiceZones_CustomDomainNotRegistered_NoZone(t *testing.T) {
account := privateZoneTestAccount(t)
// Service domain is outside the cluster and no account.Domains entry
// covers it: there is no apex that would intercept the query, so the
// service must be skipped rather than emit an unmatchable record.
account.Services[0].Domain = "app.example.com"
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "a custom-domain service with no registered domain apex must not produce a zone")
}
func TestSynthesizePrivateServiceZones_CustomDomainClusterMismatch_NoZone(t *testing.T) {
account := privateZoneTestAccount(t)
// The registered custom domain matches the service FQDN by suffix but
// targets a different cluster than the service's ProxyCluster. It must
// be ignored, leaving no apex to intercept the query — otherwise the
// zone would point at this cluster's proxy peers under a domain owned
// by a different cluster.
account.Services[0].Domain = "app.example.com"
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "us.proxy.netbird.io", Validated: true},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "a custom domain targeting a different cluster must not anchor the service zone")
}
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Services = append(account.Services, &service.Service{
@@ -362,72 +254,3 @@ func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
}
func TestSynthesizePrivateServiceZones_MixedClusterCustomAndPublic(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
privateService := func(id, domain string) *service.Service {
return &service.Service{
ID: id,
AccountID: "acct-1",
Name: id,
Domain: domain,
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
}
}
publicService := func(id, domain string) *service.Service {
s := privateService(id, domain)
s.Private = false
return s
}
account.Services = []*service.Service{
// 3 private services under the cluster suffix.
privateService("cluster-1", "cluster1.eu.proxy.netbird.io"),
privateService("cluster-2", "cluster2.eu.proxy.netbird.io"),
privateService("cluster-3", "cluster3.eu.proxy.netbird.io"),
// 4 private services under the custom domain suffix.
privateService("custom-1", "custom1.example.com"),
privateService("custom-2", "custom2.example.com"),
privateService("custom-3", "custom3.example.com"),
privateService("custom-4", "custom4.example.com"),
// 2 public services, one per suffix, must not surface.
publicService("public-cluster", "public.eu.proxy.netbird.io"),
publicService("public-custom", "public.example.com"),
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 2, "one zone per apex: the cluster apex and the custom domain apex")
cluster, ok := findCustomZone(zones, "eu.proxy.netbird.io")
require.True(t, ok, "cluster-suffix services collapse into the cluster-apex zone")
clusterNames := recordNames(cluster)
assert.ElementsMatch(t,
[]string{"cluster1.eu.proxy.netbird.io.", "cluster2.eu.proxy.netbird.io.", "cluster3.eu.proxy.netbird.io."},
clusterNames,
"only the 3 private cluster services surface in the cluster zone (public one excluded)")
custom, ok := findCustomZone(zones, "example.com")
require.True(t, ok, "custom-suffix services collapse into the custom-domain-apex zone")
customNames := recordNames(custom)
assert.ElementsMatch(t,
[]string{"custom1.example.com.", "custom2.example.com.", "custom3.example.com.", "custom4.example.com."},
customNames,
"only the 4 private custom services surface in the custom zone (public one excluded)")
}
// recordNames returns the record names of a zone for order-independent assertions.
func recordNames(zone nbdns.CustomZone) []string {
names := make([]string, 0, len(zone.Records))
for _, r := range zone.Records {
names = append(names, r.Name)
}
return names
}

View File

@@ -557,6 +557,7 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
@@ -627,14 +628,9 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
rules := []*RouteFirewallRule{&rule}
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
if includeIPv6 && r.IsDynamic() {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
if isDefaultV4 {
ruleV6.Destination = "::/0"
ruleV6.RouteID = r.ID + "-v6-default"
}
rules = append(rules, &ruleV6)
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"testing"
"time"
@@ -1030,48 +1029,6 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
}
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
// (::/0 source and destination) for peers that support IPv6, mirroring the route
// the client installs. Without it, IPv6 traffic is routed to the exit node but
// dropped at the forward chain.
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
account, validatedPeers := scalableTestAccount(20, 2)
routingPeerID := "peer-5"
routingPeer := account.Peers[routingPeerID]
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
account.Routes["route-exit"] = &route.Route{
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
PeerID: routingPeerID, Peer: routingPeer.Key,
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
AccessControlGroups: []string{},
AccountID: "test-account",
}
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
require.NotNil(t, nm)
hasV4 := false
hasV6 := false
for _, rfr := range nm.RoutesFirewallRules {
switch rfr.Destination {
case "0.0.0.0/0":
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
hasV4 = true
}
case "::/0":
if slices.Contains(rfr.SourceRanges, "::/0") {
hasV6 = true
}
}
}
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
}
// ──────────────────────────────────────────────────────────────────────────────
// 15. MULTIPLE ROUTERS PER NETWORK
// ──────────────────────────────────────────────────────────────────────────────

View File

@@ -6,6 +6,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
@@ -119,8 +120,8 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
}
// PeerConnected increments the number of connected peers and increments number of idle connections
func (m *Metrics) PeerConnected(id string) {
m.peers.Add(m.ctx, 1)
func (m *Metrics) PeerConnected(id, transport string) {
m.peers.Add(m.ctx, 1, metric.WithAttributes(attribute.String("transport", transport)))
m.mutexActivity.Lock()
defer m.mutexActivity.Unlock()
@@ -138,8 +139,8 @@ func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
}
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
func (m *Metrics) PeerDisconnected(id string) {
m.peers.Add(m.ctx, -1)
func (m *Metrics) PeerDisconnected(id, transport string) {
m.peers.Add(m.ctx, -1, metric.WithAttributes(attribute.String("transport", transport)))
m.mutexActivity.Lock()
defer m.mutexActivity.Unlock()

View File

@@ -11,4 +11,6 @@ type Conn interface {
Write(ctx context.Context, b []byte) (n int, err error)
RemoteAddr() net.Addr
Close() error
// Protocol returns the transport name.
Protocol() string
}

View File

@@ -42,6 +42,11 @@ func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
// Protocol returns the transport name for this connection.
func (c *Conn) Protocol() string {
return "quic"
}
func (c *Conn) Close() error {
c.closedMu.Lock()
if c.closed {

View File

@@ -64,6 +64,11 @@ func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr
}
// Protocol returns the transport name for this connection.
func (c *Conn) Protocol() string {
return "ws"
}
func (c *Conn) Close() error {
c.closedMu.Lock()
c.closed = true

View File

@@ -154,15 +154,16 @@ func (r *Relay) Accept(conn listener.Conn) {
}
r.notifier.PeerCameOnline(peer.ID())
transport := conn.Protocol()
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String())
r.metrics.PeerConnected(peer.String(), transport)
go func() {
peer.Work()
if deleted := r.store.DeletePeer(peer); deleted {
r.notifier.PeerWentOffline(peer.ID())
}
peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String())
r.metrics.PeerDisconnected(peer.String(), transport)
}()
if err := h.handshakeResponse(hsCtx); err != nil {

View File

@@ -9,12 +9,14 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages"
)
@@ -143,6 +145,11 @@ func (cc *connContainer) close() {
}
}
// transportConn is implemented by relay connections that know their transport.
type transportConn interface {
Protocol() string
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
@@ -172,6 +179,31 @@ type Client struct {
stateSubscription *PeersStateSubscription
mtu uint16
// transportFallback, when set, records datagram-too-large failures so a
// datagram-sized transport is avoided on subsequent connects. Shared via
// the manager.
transportFallback *transportFallback
// datagramFallbackTriggered guards a single fallback per connection so a
// burst of oversized datagrams triggers one reconnect, not many.
datagramFallbackTriggered atomic.Bool
// transport is the negotiated relay transport of the
// current connection, guarded by mu.
transport string
}
// Transport returns the negotiated relay transport of the current connection,
// or an empty string when not connected.
func (c *Client) Transport() string {
c.mu.Lock()
defer c.mu.Unlock()
return c.transport
}
// SetTransportFallback wires the shared datagram-transport fallback tracker.
func (c *Client) SetTransportFallback(tf *transportFallback) {
c.transportFallback = tf
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
@@ -361,12 +393,13 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
dialers := c.getDialers()
mode := transportModeFromEnv()
dialers := c.getDialers(mode)
var conn net.Conn
if c.serverIP.IsValid() {
var err error
conn, err = c.dialRaceDirect(ctx, dialers)
conn, err = c.dialRaceDirect(ctx, mode, dialers)
if err != nil {
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
conn = nil
@@ -375,6 +408,9 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
if conn == nil {
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
if mode.sequential() {
rd.WithSequential()
}
var err error
conn, err = rd.Dial(ctx)
if err != nil {
@@ -382,6 +418,10 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
}
c.relayConn = conn
c.datagramFallbackTriggered.Store(false)
if tc, ok := conn.(transportConn); ok {
c.transport = tc.Protocol()
}
instanceURL, err := c.handShake(ctx)
if err != nil {
@@ -396,7 +436,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
if err != nil {
return nil, fmt.Errorf("substitute host: %w", err)
@@ -406,6 +446,9 @@ func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
WithServerName(serverName)
if mode.sequential() {
rd.WithSequential()
}
return rd.Dial(ctx)
}
@@ -631,13 +674,53 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
conn := c.relayConn
_, err = conn.Write(msg)
if err != nil {
c.log.Errorf("failed to write transport message: %s", err)
if errors.Is(err, netErr.ErrDatagramTooLarge) {
c.onDatagramTooLarge(conn, err)
} else {
c.log.Errorf("failed to write transport message: %s", err)
}
}
return len(payload), err
}
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
// When a non-datagram transport is available, it records a fallback for this
// server and closes the connection so the reconnect avoids datagram-sized
// transports. A single fallback is triggered per connection regardless of how
// many oversized datagrams arrive. cause carries the datagram size and budget.
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
// Handle one oversized datagram per connection; a burst triggers a single
// fallback (and a single log line), not many.
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
return
}
// If the selected mode offers no non-datagram transport (e.g. pinned to a
// datagram-sized transport), reconnecting would just re-fail, so leave the
// connection up rather than loop.
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
return
}
// Without the shared tracker a reconnect would just select the same
// transport again and re-fail, so leave the connection up rather than loop.
if c.transportFallback == nil {
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
return
}
window := c.transportFallback.recordFailure(c.connectionURL)
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
if err := conn.Close(); err != nil {
c.log.Debugf("close relay connection for transport fallback: %s", err)
}
}
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for {
select {
@@ -729,6 +812,7 @@ func (c *Client) close(gracefullyExit bool) error {
return nil
}
c.serviceIsRunning = false
c.transport = ""
c.muInstanceURL.Lock()
c.instanceURL = nil

View File

@@ -0,0 +1,18 @@
package dialer
// DatagramSized is implemented by dialers whose connections carry each write in
// a single datagram, so a write can be rejected when it exceeds the path's
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
// WebSocket over TCP) impose no per-write size limit, so the relay client can
// fall back to them when a datagram-sized transport rejects a write as too
// large. The capability is advertised per dialer rather than hardcoded, so a
// new transport only needs to declare whether it is datagram-sized.
type DatagramSized interface {
DatagramSized()
}
// IsDatagramSized reports whether d produces datagram-sized connections.
func IsDatagramSized(d DialeFn) bool {
_, ok := d.(DatagramSized)
return ok
}

View File

@@ -4,4 +4,9 @@ import "errors"
var (
ErrClosedByServer = errors.New("closed by server")
// ErrDatagramTooLarge is returned when a transport message exceeds the
// QUIC datagram size the path to the relay can carry. The relay client
// treats it as a signal to fall back to a non-datagram transport.
ErrDatagramTooLarge = errors.New("datagram frame too large")
)

View File

@@ -8,7 +8,6 @@ import (
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
@@ -52,15 +51,17 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) Write(b []byte) (int, error) {
err := c.session.SendDatagram(b)
if err != nil {
err = c.remoteCloseErrHandling(err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
if err := c.session.SendDatagram(b); err != nil {
return 0, c.writeErrHandling(err, len(b))
}
return len(b), nil
}
// Protocol returns the transport name for this connection.
func (c *Conn) Protocol() string {
return Network
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
@@ -95,3 +96,15 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
}
return err
}
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
// datagram size and path budget) so the relay client can fall back to a
// non-datagram transport.
func (c *Conn) writeErrHandling(err error, size int) error {
var tooLarge *quic.DatagramTooLargeError
if errors.As(err, &tooLarge) {
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
}
return c.remoteCloseErrHandling(err)
}

View File

@@ -2,13 +2,13 @@ package quic
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/client/net"
@@ -23,6 +23,12 @@ func (d Dialer) Protocol() string {
return Network
}
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
// carried in QUIC DATAGRAM frames, which must fit a single packet.
func (d Dialer) DatagramSized() {
// Intentional marker method; presence is the capability signal.
}
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
@@ -47,26 +53,21 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize,
Tracer: connectionTracer(quicURL),
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
if err != nil {
log.Errorf("failed to listen on UDP: %s", err)
return nil, err
return nil, fmt.Errorf("listen udp: %w", err)
}
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
if err != nil {
log.Errorf("failed to resolve UDP address: %s", err)
return nil, err
return nil, fmt.Errorf("resolve %s: %w", quicURL, err)
}
session, err := quic.Dial(ctx, udpConn, udpAddr, tlsClientConfig, quicConfig)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
return nil, err
}
@@ -74,6 +75,28 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
return conn, nil
}
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
// reason a relay connection closed, so the path MTU settled on and teardown
// cause are visible in logs. Lines carry the relay address as a structured
// field, matching the rest of the relay client logging.
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
relayLog := log.WithField("relay", addr)
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
if done {
relayLog.Infof("QUIC path MTU settled at %d", mtu)
return
}
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
},
ClosedConnection: func(err error) {
relayLog.Debugf("QUIC connection closed: %v", err)
},
}
}
}
func prepareURL(address string) (string, error) {
var host string
var defaultPort string

View File

@@ -3,6 +3,7 @@ package dialer
import (
"context"
"errors"
"fmt"
"net"
"time"
@@ -32,6 +33,7 @@ type RaceDial struct {
serverName string
dialerFns []DialeFn
connectionTimeout time.Duration
sequential bool
}
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
@@ -53,9 +55,24 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
return r
}
// WithSequential makes Dial try the dialers in order, falling back to the next
// only when one fails to connect, instead of racing them concurrently.
//
// Mutates the receiver and is not safe for concurrent reconfiguration; a
// RaceDial is intended to be constructed per dial and discarded.
func (r *RaceDial) WithSequential() *RaceDial {
r.sequential = true
return r
}
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
if r.sequential {
return r.dialSequential(ctx)
}
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
errChan := make(chan error, 1)
abortCtx, abort := context.WithCancel(ctx)
defer abort()
@@ -63,15 +80,41 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
go r.dial(dfn, abortCtx, connChan)
}
go r.processResults(connChan, winnerConn, abort)
go r.processResults(connChan, winnerConn, errChan, abort)
conn, ok := <-winnerConn
if !ok {
return nil, errors.New("failed to dial to Relay server on any protocol")
return nil, <-errChan
}
return conn, nil
}
// dialSequential tries each dialer in order, returning the first connection and
// falling back to the next on failure.
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
var errs []error
for _, dfn := range r.dialerFns {
if err := ctx.Err(); err != nil {
return nil, err
}
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
cancel()
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
errs = append(errs, fmt.Errorf("%s: %w", dfn.Protocol(), err))
continue
}
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
return conn, nil
}
return nil, dialErr(errs)
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel()
@@ -81,8 +124,9 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
}
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, errChan chan error, abort context.CancelFunc) {
var hasWinner bool
errsByProtocol := make(map[string]error)
for i := 0; i < len(r.dialerFns); i++ {
dr := <-connChan
if dr.Err != nil {
@@ -90,6 +134,7 @@ func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
} else {
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
errsByProtocol[dr.Protocol] = fmt.Errorf("%s: %w", dr.Protocol, dr.Err)
}
continue
}
@@ -107,5 +152,29 @@ func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.
hasWinner = true
winnerConn <- dr.Conn
}
if !hasWinner {
errChan <- dialErr(r.orderedErrs(errsByProtocol))
}
close(winnerConn)
}
// orderedErrs returns the per-protocol errors in dialer order, so the combined
// error is stable regardless of which attempt failed first.
func (r *RaceDial) orderedErrs(byProtocol map[string]error) []error {
errs := make([]error, 0, len(byProtocol))
for _, dfn := range r.dialerFns {
if err, ok := byProtocol[dfn.Protocol()]; ok {
errs = append(errs, err)
}
}
return errs
}
// dialErr combines per-dialer failures, preserving the underlying reasons
// (e.g. "connection refused") rather than a generic message.
func dialErr(errs []error) error {
if len(errs) == 0 {
return errors.New("no relay transport available")
}
return errors.Join(errs...)
}

View File

@@ -250,3 +250,66 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
}
}
}
func TestRaceDialSequentialFallback(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
var firstDialed, secondDialed bool
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
firstDialed = true
return nil, errors.New("quic unreachable")
},
}
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
secondDialed = true
return fallbackConn, nil
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected fallback to succeed, got %v", err)
}
if conn != fallbackConn {
t.Errorf("expected fallback connection, got %v", conn)
}
if !firstDialed || !secondDialed {
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
}
}
func TestRaceDialSequentialPreferredWins(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return preferredConn, nil
},
}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
t.Errorf("fallback dialer must not be tried when preferred succeeds")
return nil, errors.New("should not happen")
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected preferred to succeed, got %v", err)
}
if conn != preferredConn {
t.Errorf("expected preferred connection, got %v", conn)
}
}

View File

@@ -33,6 +33,11 @@ func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn)
}
}
// Protocol returns the transport name for this connection.
func (c *Conn) Protocol() string {
return Network
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {

View File

@@ -22,7 +22,7 @@ type Dialer struct {
}
func (d Dialer) Protocol() string {
return "WS"
return Network
}
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
@@ -39,7 +39,12 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
// websocket.Dial wraps the cause in verbose layers; surface the
// underlying network error when present.
var opErr *net.OpError
if errors.As(err, &opErr) {
return nil, opErr
}
return nil, err
}
if resp.Body != nil {

View File

@@ -9,11 +9,42 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// getDialers returns the list of dialers to use for connecting to the relay server.
func (c *Client) getDialers() []dialer.DialeFn {
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
return []dialer.DialeFn{ws.Dialer{}}
// getDialers returns the ordered dialers for connecting to the relay server. It
// applies the datagram fallback generically: if this server recently rejected a
// datagram-sized transport, those dialers are dropped, leaving the rest.
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
dialers := c.baseDialers(mode)
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
return filtered
}
}
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
return dialers
}
// baseDialers returns the ordered dialers for the mode, before any datagram
// fallback filtering. For racing modes (auto) the order is irrelevant; for
// prefer modes the first entry is tried before falling back to the second.
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
switch mode {
case TransportModeWS:
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
return []dialer.DialeFn{ws.Dialer{}}
case TransportModeQUIC:
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
return []dialer.DialeFn{quic.Dialer{}}
}
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
if mode == TransportModePreferWS {
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
}
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
return nonDatagramSized(all)
}
return all
}

View File

@@ -0,0 +1,101 @@
//go:build !js
package client
import (
"os"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// TestDatagramSizedCapability locks the capability the generic fallback relies
// on: QUIC is datagram-sized, WebSocket is not.
func TestDatagramSizedCapability(t *testing.T) {
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
}
func protocols(dialers []dialer.DialeFn) []string {
out := make([]string, len(dialers))
for i, d := range dialers {
out[i] = d.Protocol()
}
return out
}
func TestGetDialers(t *testing.T) {
const url = "rels://relay.example:443"
tests := []struct {
name string
mode string
mtu uint16
preferWS bool
want []string
}{
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "ws"}},
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"ws"}},
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "ws"}},
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"ws", "quic"}},
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"ws"}},
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"ws"}},
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"ws"}},
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.mode)
if tc.mode == "" {
os.Unsetenv(EnvRelayTransport)
}
tf := newTransportFallback()
if tc.preferWS {
tf.recordFailure(url)
}
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: tc.mtu,
transportFallback: tf,
}
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
})
}
}
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
// datagram records a fallback that makes the next dial pick WebSocket, the way a
// reconnect would after the connection is closed.
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: iface.DefaultMTU,
transportFallback: newTransportFallback(),
}
// First dial races both transports.
assert.Equal(t, []string{"quic", "ws"}, protocols(c.getDialers(transportModeFromEnv())))
// An oversized datagram records the fallback for this server.
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
// The reconnect now sticks to WebSocket.
assert.Equal(t, []string{"ws"}, protocols(c.getDialers(transportModeFromEnv())))
}

View File

@@ -7,7 +7,11 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
func (c *Client) getDialers() []dialer.DialeFn {
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
// JS/WASM build only uses WebSocket transport
return []dialer.DialeFn{ws.Dialer{}}
}
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
return []dialer.DialeFn{ws.Dialer{}}
}

View File

@@ -2,6 +2,7 @@ package client
import (
"context"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
@@ -20,6 +21,10 @@ type Guard struct {
// maxBackoffInterval caps the exponential backoff between reconnect
// attempts.
maxBackoffInterval time.Duration
// lastErr is the error from the most recent failed reconnect attempt,
// surfaced as the home relay status while disconnected.
lastErr atomic.Pointer[error]
}
// NewGuard creates a new guard for the relay client. A non-positive
@@ -37,6 +42,19 @@ func NewGuard(sp *ServerPicker, maxBackoffInterval time.Duration) *Guard {
return g
}
// LastError returns the error from the most recent failed reconnect attempt, or
// nil if reconnection last succeeded.
func (g *Guard) LastError() error {
if p := g.lastErr.Load(); p != nil {
return *p
}
return nil
}
func (g *Guard) setLastError(err error) {
g.lastErr.Store(&err)
}
// StartReconnectTrys is called when the relay client is disconnected from the relay server.
// It attempts to reconnect to the relay server. The function first tries a quick reconnect
// to the same server that was used before, if the server URL is still valid. If the quick
@@ -63,6 +81,7 @@ func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
case <-ticker.C:
if err := g.retry(ctx); err != nil {
log.Errorf("failed to pick new Relay server: %s", err)
g.setLastError(err)
continue
}
return
@@ -89,6 +108,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool
if err := rc.Connect(parentCtx); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
g.setLastError(err)
return false
}
return true
@@ -100,6 +120,7 @@ func (g *Guard) retry(ctx context.Context) error {
if err != nil {
return err
}
g.setLastError(nil)
// prevent to work with a deprecated Relay client instance
g.drainRelayClientChan()
@@ -125,6 +146,7 @@ func (g *Guard) isServerURLStillValid(rc *Client) bool {
}
func (g *Guard) notifyReconnected() {
g.setLastError(nil)
select {
case g.OnReconnected <- struct{}{}:
default:

View File

@@ -79,23 +79,30 @@ type Manager struct {
cleanupInterval time.Duration
keepUnusedServerTime time.Duration
// transportFallback is shared across home and foreign relay clients so a
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
transportFallback *transportFallback
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
tokenStore := &relayAuth.TokenStore{}
tf := newTransportFallback()
m := &Manager{
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
transportFallback: tf,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
MTU: mtu,
ConnectionTimeout: defaultConnectionTimeout,
TransportFallback: tf,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
@@ -123,6 +130,9 @@ func (m *Manager) Serve() error {
client, err := m.serverPicker.PickServer(m.ctx)
if err != nil {
// record the initial failure so status shows the real reason before
// the guard's first retry tick
m.reconnectGuard.setLastError(err)
go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
} else {
m.storeClient(client)
@@ -235,6 +245,67 @@ func (m *Manager) ServerURLs() []string {
return m.serverPicker.ServerURLs.Load().([]string)
}
// RelayConnectError returns the error from the most recent failed home relay
// reconnect attempt, or nil if the relay last connected successfully.
func (m *Manager) RelayConnectError() error {
return m.reconnectGuard.LastError()
}
// RelayConnState is the connection state of a single relay server.
type RelayConnState struct {
// URL is the server's instance address when connected, otherwise the
// configured server URL.
URL string
// Transport is the negotiated transport, empty if not connected.
Transport string
// Err is set when the relay is not connected.
Err error
}
// RelayStates returns the connection state of the home relay and every foreign
// relay the manager currently tracks.
func (m *Manager) RelayStates() []RelayConnState {
var states []RelayConnState
m.relayClientMu.RLock()
home := m.relayClient
m.relayClientMu.RUnlock()
if home != nil {
st := relayConnState(home)
// The home relay reconnects through the guard, so the real failure
// reason lives there rather than on the (stale) client.
if st.Err != nil {
if gErr := m.reconnectGuard.LastError(); gErr != nil {
st.Err = gErr
}
}
states = append(states, st)
}
// Snapshot the tracks, then query each outside the map lock: a track can be
// held by an in-progress Connect, and blocking on it must not stall other
// relay operations.
m.relayClientsMutex.RLock()
tracks := make([]*RelayTrack, 0, len(m.relayClients))
for _, rt := range m.relayClients {
tracks = append(tracks, rt)
}
m.relayClientsMutex.RUnlock()
// Only connected foreign relays carry state; a failed connect is evicted
// immediately (openConnVia), so there is no error state to surface.
for _, rt := range tracks {
rt.RLock()
rc := rt.relayClient
rt.RUnlock()
if rc != nil {
states = append(states, relayConnState(rc))
}
}
return states
}
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
// Relay service.
func (m *Manager) HasRelayAddress() bool {
@@ -287,6 +358,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
m.relayClientsMutex.Unlock()
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
relayClient.SetTransportFallback(m.transportFallback)
err := relayClient.Connect(m.ctx)
if err != nil {
rt.err = err
@@ -452,3 +524,11 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
}
delete(m.onDisconnectedListeners, serverAddress)
}
func relayConnState(c *Client) RelayConnState {
addr, err := c.ServerInstanceURL()
if err != nil {
return RelayConnState{URL: c.connectionURL, Err: err}
}
return RelayConnState{URL: addr, Transport: c.Transport()}
}

View File

@@ -29,6 +29,7 @@ type ServerPicker struct {
PeerID string
MTU uint16
ConnectionTimeout time.Duration
TransportFallback *transportFallback
}
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
@@ -39,6 +40,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
errChan := make(chan error, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
@@ -53,23 +55,24 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
}(url)
}
go sp.processConnResults(connResultChan, successChan)
go sp.processConnResults(connResultChan, successChan, errChan)
select {
case cr, ok := <-successChan:
if !ok {
return nil, errors.New("failed to connect to any relay server: all attempts failed")
return nil, <-errChan
}
log.Infof("chosen home Relay server: %s", cr.Url)
return cr.RelayClient, nil
case <-ctx.Done():
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
return nil, fmt.Errorf("connect to relay server: %w", ctx.Err())
}
}
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
relayClient.SetTransportFallback(sp.TransportFallback)
err := relayClient.Connect(ctx)
resultChan <- connResult{
RelayClient: relayClient,
@@ -78,12 +81,14 @@ func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan con
}
}
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult, errChan chan error) {
var hasSuccess bool
var errs []error
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil {
log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
errs = append(errs, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)
@@ -99,5 +104,16 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh
hasSuccess = true
successChan <- cr
}
if !hasSuccess {
errChan <- pickErr(errs)
}
close(successChan)
}
// pickErr combines per-server connection failures into a single error.
func pickErr(errs []error) error {
if len(errs) == 0 {
return errors.New("no relay server available")
}
return errors.Join(errs...)
}

View File

@@ -0,0 +1,129 @@
package client
import (
"os"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
)
// EnvRelayTransport pins the relay transport. Valid values: "auto" (default,
// race QUIC and WebSocket), "quic" (QUIC only), "ws" (WebSocket only),
// "prefer-quic" / "prefer-ws" (try the preferred transport first, fall back to
// the other only if it fails to connect; no race). The prefer modes trade a
// slower connect when the preferred transport is blackholed for deterministic
// transport selection.
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
const (
// transportFallbackBase is the initial window a relay server avoids
// datagram-sized transports after a datagram is rejected as too large.
transportFallbackBase = 10 * time.Minute
// transportFallbackMax caps the pinned window when failures repeat.
transportFallbackMax = 60 * time.Minute
)
// TransportMode selects which relay dialers are used.
type TransportMode string
const (
TransportModeAuto TransportMode = "auto"
TransportModeQUIC TransportMode = "quic"
TransportModeWS TransportMode = "ws"
TransportModePreferQUIC TransportMode = "prefer-quic"
TransportModePreferWS TransportMode = "prefer-ws"
)
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
// or unrecognized value.
func transportModeFromEnv() TransportMode {
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
case "", TransportModeAuto:
return TransportModeAuto
case TransportModeQUIC:
return TransportModeQUIC
case TransportModeWS:
return TransportModeWS
case TransportModePreferQUIC:
return TransportModePreferQUIC
case TransportModePreferWS:
return TransportModePreferWS
default:
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
return TransportModeAuto
}
}
// sequential reports whether the mode tries dialers in order with fallback
// instead of racing them concurrently.
func (m TransportMode) sequential() bool {
return m == TransportModePreferQUIC || m == TransportModePreferWS
}
// transportFallback tracks relay servers that have rejected a datagram-sized
// transport (a write too large for the path) and should temporarily avoid such
// transports. It is shared across the relay manager so the preference survives
// client recreation (foreign relay clients are evicted and rebuilt on
// disconnect). Entries are keyed by server URL and expire after a window that
// grows on repeated failures.
type transportFallback struct {
mu sync.Mutex
entries map[string]*fallbackEntry
}
type fallbackEntry struct {
until time.Time
duration time.Duration
}
func newTransportFallback() *transportFallback {
return &transportFallback{entries: make(map[string]*fallbackEntry)}
}
// avoidDatagramSized reports whether serverURL is currently within a window
// where datagram-sized transports should be avoided.
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
f.mu.Lock()
defer f.mu.Unlock()
e := f.entries[serverURL]
return e != nil && time.Now().Before(e.until)
}
// recordFailure makes serverURL avoid datagram-sized transports for a window:
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
// when a datagram transport fails again after a previous window expired. It
// returns the active window duration.
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
f.mu.Lock()
defer f.mu.Unlock()
now := time.Now()
e := f.entries[serverURL]
switch {
case e == nil:
e = &fallbackEntry{duration: transportFallbackBase}
f.entries[serverURL] = e
case now.Before(e.until):
return time.Until(e.until)
default:
e.duration = min(e.duration*2, transportFallbackMax)
}
e.until = now.Add(e.duration)
return e.duration
}
// nonDatagramSized returns the dialers from in that are not datagram-sized,
// preserving order.
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
out := make([]dialer.DialeFn, 0, len(in))
for _, d := range in {
if !dialer.IsDatagramSized(d) {
out = append(out, d)
}
}
return out
}

View File

@@ -0,0 +1,140 @@
package client
import (
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
// closeTrackingConn records whether Close was called; only Close is exercised.
type closeTrackingConn struct {
net.Conn
closed bool
}
func (c *closeTrackingConn) Close() error {
c.closed = true
return nil
}
func TestTransportModeFromEnv(t *testing.T) {
tests := []struct {
value string
want TransportMode
}{
{"", TransportModeAuto},
{"auto", TransportModeAuto},
{"quic", TransportModeQUIC},
{"QUIC", TransportModeQUIC},
{"ws", TransportModeWS},
{" Ws ", TransportModeWS},
{"prefer-quic", TransportModePreferQUIC},
{"prefer-ws", TransportModePreferWS},
{"garbage", TransportModeAuto},
}
for _, tc := range tests {
t.Run(tc.value, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.value)
if tc.value == "" {
os.Unsetenv(EnvRelayTransport)
}
assert.Equal(t, tc.want, transportModeFromEnv())
})
}
}
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
d := f.recordFailure(url)
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
// A second failure while still inside the window must not grow the window.
d = f.recordFailure(url)
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
require.NotNil(t, f.entries[url])
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
// Expire the window: datagram-sized transport allowed again.
f.entries[url].until = time.Now().Add(-time.Second)
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
}
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
want := transportFallbackBase
for i := range 6 {
d := f.recordFailure(url)
assert.Equal(t, want, d, "window after %d expiries", i)
// expire the window so the next failure is treated as a repeat
f.entries[url].until = time.Now().Add(-time.Second)
want = min(want*2, transportFallbackMax)
}
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
}
func TestOnDatagramTooLargeAuto(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.True(t, conn.closed, "connection closed to force reconnect")
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
// A second oversized datagram on the same connection must not re-close.
conn.closed = false
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "single fallback per connection")
}
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
}
func TestTransportFallbackPerServer(t *testing.T) {
f := newTransportFallback()
f.recordFailure("rels://a.example:443")
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
}

View File

@@ -1,16 +1,15 @@
package util
import (
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strconv"
"github.com/DeRuina/timberjack"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/grpclog"
"gopkg.in/natefinch/lumberjack.v2"
"github.com/netbirdio/netbird/formatter"
)
@@ -38,7 +37,8 @@ func InitLog(logLevel string, logs ...string) error {
func InitLogger(logger *log.Logger, logLevel string, logs ...string) error {
level, err := log.ParseLevel(logLevel)
if err != nil {
return fmt.Errorf("failed parsing log-level %s: %w", logLevel, err)
logger.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err
}
var writers []io.Writer
logFmt := os.Getenv("NB_LOG_FORMAT")
@@ -59,11 +59,7 @@ func InitLogger(logger *log.Logger, logLevel string, logs ...string) error {
case "":
logger.Warnf("empty log path received: %#v", logPath)
default:
writer, err := setupLogFile(logPath, isRotationDisabled(logger))
if err != nil {
return fmt.Errorf("failed setting up log file: %s, %w", logPath, err)
}
writers = append(writers, writer)
writers = append(writers, newRotatedOutput(logPath))
}
}
@@ -98,43 +94,17 @@ func FindFirstLogPath(logs []string) string {
return ""
}
func isRotationDisabled(logger *log.Logger) bool {
v, _ := os.LookupEnv("NB_LOG_DISABLE_ROTATION")
disabled, _ := strconv.ParseBool(v)
if disabled {
logger.Warnf("log rotation is disabled by env flag")
return true
}
conflict, configPath := FindFirstLogrotateConflict()
if conflict {
logger.Warnf("log rotation conflict detected in: %#v, rotation is disabled", configPath)
return true
}
return false
}
func setupLogFile(logPath string, disableRotation bool) (io.Writer, error) {
if disableRotation {
file, err := os.OpenFile(logPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
if err != nil {
return nil, err
}
return file, nil
}
return newRotatedOutput(logPath), nil
}
func newRotatedOutput(logPath string) io.Writer {
maxLogSize := getLogMaxSize()
timberjackLogger := &timberjack.Logger{
lumberjackLogger := &lumberjack.Logger{
// Log file absolute path, os agnostic
Filename: filepath.ToSlash(logPath),
MaxSize: maxLogSize, // MB
MaxBackups: 10,
MaxAge: 30, // days
Compression: "gzip",
Filename: filepath.ToSlash(logPath),
MaxSize: maxLogSize, // MB
MaxBackups: 10,
MaxAge: 30, // days
Compress: true,
}
return timberjackLogger
return lumberjackLogger
}
func setGRPCLibLogger(logger *log.Logger) {
@@ -157,7 +127,7 @@ func getLogMaxSize() int {
if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok {
size, err := strconv.ParseInt(sizeVar, 10, 64)
if err != nil {
log.Errorf("failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
return defaultLogSize
}

View File

@@ -1,96 +0,0 @@
package util
import (
"io"
"os"
"path/filepath"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
// TestSetupLogFile_RotatesOnSize drives >MaxSize bytes through the writer
// returned by setupLogFile and asserts a backup file appears.
func TestSetupLogFile_RotatesOnSize(t *testing.T) {
t.Setenv("NB_LOG_MAX_SIZE_MB", "1")
dir := t.TempDir()
logPath := filepath.Join(dir, "netbird.log")
w, err := setupLogFile(logPath, false)
require.NoError(t, err)
t.Cleanup(func() {
if c, ok := w.(io.Closer); ok {
_ = c.Close()
}
})
chunk := []byte(strings.Repeat("x", 64*1024) + "\n")
for range 20 {
_, err := w.Write(chunk)
require.NoError(t, err)
}
info, err := os.Stat(logPath)
require.NoError(t, err)
require.Less(t, info.Size(), int64(1<<20),
"active log should be < 1 MB after rotation, got %d", info.Size())
require.Eventually(t, func() bool {
entries, _ := os.ReadDir(dir)
for _, e := range entries {
name := e.Name()
if name == filepath.Base(logPath) {
continue
}
if strings.HasPrefix(name, "netbird-") && strings.HasSuffix(name, ".log.gz") {
return true
}
}
return false
}, 5*time.Second, 50*time.Millisecond, "expected a rotated backup file in %s", dir)
}
// TestSetupLogFile_RotationDisabled verifies that with rotation off, the file
// grows past MaxSize and no backups are created.
func TestSetupLogFile_RotationDisabled(t *testing.T) {
t.Setenv("NB_LOG_MAX_SIZE_MB", "1")
dir := t.TempDir()
logPath := filepath.Join(dir, "netbird.log")
w, err := setupLogFile(logPath, true)
require.NoError(t, err)
f, ok := w.(*os.File)
require.True(t, ok, "expected plain *os.File when rotation is disabled, got %T", w)
t.Cleanup(func() { _ = f.Close() })
chunk := []byte(strings.Repeat("y", 64*1024) + "\n")
for range 20 {
_, err := w.Write(chunk)
require.NoError(t, err)
}
info, err := os.Stat(logPath)
require.NoError(t, err)
require.GreaterOrEqual(t, info.Size(), int64(1<<20),
"file should exceed MaxSize when rotation is disabled, got %d", info.Size())
entries, err := os.ReadDir(dir)
require.NoError(t, err)
require.Len(t, entries, 1, "no backup files should exist when rotation is disabled, got %v", entries)
}
// TestIsRotationDisabled_EnvFlag covers the NB_LOG_DISABLE_ROTATION env path.
// The logrotate-conflict branch is exercised separately on linux.
func TestIsRotationDisabled_EnvFlag(t *testing.T) {
logger := log.New()
logger.SetOutput(io.Discard)
t.Setenv("NB_LOG_DISABLE_ROTATION", "true")
require.True(t, isRotationDisabled(logger))
}

View File

@@ -1,93 +0,0 @@
//go:build linux
package util
import (
"bufio"
"errors"
"io/fs"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
const (
defaultLogrotateConfPath = "/etc/logrotate.conf"
defaultLogrotateConfDir = "/etc/logrotate.d"
netbirdString = "netbird"
)
// FindLogrotateConflicts scans the standard logrotate locations for
// indications of conflict with netbird. It returns true and the config file
// path if a conflict was found.
func FindFirstLogrotateConflict() (bool, string) {
return findFirstLogrotateConflictIn(defaultLogrotateConfPath, defaultLogrotateConfDir)
}
func findFirstLogrotateConflictIn(confPath, confDir string) (bool, string) {
for _, f := range listLogrotateConfigs(confPath, confDir) {
present, err := scanLogrotateFile(f, netbirdString)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Debugf("scan %s: %v", f, err)
}
continue
}
if present {
return present, f
}
}
return false, ""
}
// listLogrotateConfigs returns all config files for logrotate.
func listLogrotateConfigs(confPath, confDir string) []string {
files := []string{confPath}
entries, err := os.ReadDir(confDir)
if err != nil {
return files
}
for _, e := range entries {
if e.IsDir() {
continue
}
files = append(files, filepath.Join(confDir, e.Name()))
}
return files
}
// scanLogrotateFile reads a config and reports if a non-comment line
// contains the given substring.
func scanLogrotateFile(path string, substring string) (bool, error) {
f, err := os.Open(path)
if err != nil {
return false, err
}
defer func() {
if err := f.Close(); err != nil {
log.Debugf("close %s: %v", path, err)
}
}()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(stripLogrotateComment(scanner.Text()))
if line == "" {
continue
}
if strings.Contains(line, substring) {
return true, nil
}
}
if err := scanner.Err(); err != nil {
return false, err
}
return false, nil
}
func stripLogrotateComment(line string) string {
before, _, _ := strings.Cut(line, "#")
return before
}

View File

@@ -1,95 +0,0 @@
//go:build linux
package util
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TestFindFirstLogrotateConflict(t *testing.T) {
t.Run("conflict in confDir", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
conflictPath := filepath.Join(confDir, "netbird")
writeLogrotateConfig(t, conflictPath, `/var/log/netbird/*.log {
daily
rotate 7
}`)
writeLogrotateConfig(t, filepath.Join(confDir, "nginx"), `/var/log/nginx/*.log { daily }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.True(t, got)
require.Equal(t, conflictPath, path)
})
t.Run("conflict in main conf file", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
writeLogrotateConfig(t, confPath, `weekly
rotate 4
include /etc/logrotate.d
/var/log/netbird/client.log { rotate 5 }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.True(t, got)
require.Equal(t, confPath, path)
})
t.Run("no conflict when netbird is absent", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
writeLogrotateConfig(t, filepath.Join(confDir, "nginx"), `/var/log/nginx/*.log { daily }`)
writeLogrotateConfig(t, filepath.Join(confDir, "syslog"), `/var/log/syslog { weekly }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.False(t, got)
require.Empty(t, path)
})
t.Run("commented-out netbird line is ignored", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
writeLogrotateConfig(t, filepath.Join(confDir, "misc"), `# /var/log/netbird/*.log { daily }
/var/log/other.log { weekly }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.False(t, got)
require.Empty(t, path)
})
t.Run("subdirectories in confDir are ignored", func(t *testing.T) {
confPath, confDir := newLogrotateLayout(t)
sub := filepath.Join(confDir, "nested")
require.NoError(t, os.MkdirAll(sub, 0o755))
writeLogrotateConfig(t, filepath.Join(sub, "netbird"), `/var/log/netbird/*.log { daily }`)
got, path := findFirstLogrotateConflictIn(confPath, confDir)
require.False(t, got)
require.Empty(t, path)
})
t.Run("missing paths return no conflict", func(t *testing.T) {
dir := t.TempDir()
got, path := findFirstLogrotateConflictIn(
filepath.Join(dir, "does-not-exist.conf"),
filepath.Join(dir, "does-not-exist.d"),
)
require.False(t, got)
require.Empty(t, path)
})
}
// newLogrotateLayout creates a temp logrotate.conf path and logrotate.d dir,
// returning their paths. The conf file itself is not created.
func newLogrotateLayout(t *testing.T) (confPath, confDir string) {
t.Helper()
root := t.TempDir()
confDir = filepath.Join(root, "logrotate.d")
require.NoError(t, os.MkdirAll(confDir, 0o755))
return filepath.Join(root, "logrotate.conf"), confDir
}
func writeLogrotateConfig(t *testing.T, path, body string) {
t.Helper()
require.NoError(t, os.WriteFile(path, []byte(body), 0o644))
}

View File

@@ -1,10 +0,0 @@
//go:build !linux
package util
// FindLogrotateConflicts scans the standard logrotate locations for
// indications of conflict with netbird. It will always return false for
// non-linux devices.
func FindFirstLogrotateConflict() (bool, string) {
return false, ""
}